mirror of https://github.com/huggingface/lerobot.git
synced 2026-05-12 07:09:43 +00:00

Compare commits (9 commits):

- 1de2a4a828
- 55e752f0c2
- 847e74f628
- 33cad37054
- f55c6e89f0
- d602e8169c
- 49baccdccb
- 6a3d57031a
- d74494d92b
@@ -233,7 +233,7 @@ Under the hood, the `LeRobotDataset` format makes use of several ways to seriali

Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset, but not the main aspects:

```
dataset attributes:
├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example:
│ ├ observation.images.cam_high (VideoFrame):

@@ -246,20 +246,30 @@ dataset attributes:

│ ├ timestamp (float32): timestamp in the episode
│ ├ next.done (bool): indicates the end of an episode; True for the last frame in each episode
│ └ index (int64): general index in the whole dataset
├ episode_data_index: contains 2 tensors with the start and end indices of each episode
│ ├ from (1D int64 tensor): first frame index for each episode — shape (num episodes,), starts with 0
│ └ to (1D int64 tensor): last frame index for each episode — shape (num episodes,)
├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
│ ├ observation.images.cam_high: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
│ ...
├ info: a dictionary of metadata on the dataset
│ ├ codebase_version (str): this is to keep track of the codebase version the dataset was created with
│ ├ fps (float): frames per second the dataset is recorded/synchronized to
│ ├ video (bool): indicates if frames are encoded in mp4 video files to save space, or stored as png files
│ └ encoding (dict): if video, this documents the main options that were used with ffmpeg to encode the videos
├ videos_dir (Path): where the mp4 videos or png images are stored/accessed
└ camera_keys (list of string): the keys to access camera features in the item returned by the dataset (e.g. `["observation.images.cam_high", ...]`)
```

```
├ meta: a LeRobotDatasetMetadata object containing:
│ ├ info: a dictionary of metadata on the dataset
│ │ ├ codebase_version (str): this is to keep track of the codebase version the dataset was created with
│ │ ├ fps (int): frames per second the dataset is recorded/synchronized to
│ │ ├ features (dict): all features contained in the dataset, with their shapes and types
│ │ ├ total_episodes (int): total number of episodes in the dataset
│ │ ├ total_frames (int): total number of frames in the dataset
│ │ ├ robot_type (str): robot type used for recording
│ │ ├ data_path (str): formattable string for the parquet files
│ │ └ video_path (str): formattable string for the video files (if using videos)
│ ├ episodes: a DataFrame containing episode metadata, with columns:
│ │ ├ episode_index (int): index of the episode
│ │ ├ tasks (list): list of tasks for this episode
│ │ ├ length (int): number of frames in this episode
│ │ ├ dataset_from_index (int): start index of this episode in the dataset
│ │ └ dataset_to_index (int): end index of this episode in the dataset
│ ├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
│ │ ├ observation.images.front_cam: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
│ │ └ ...
│ └ tasks: a DataFrame containing task information, with task names as index and task_index as values
├ root (Path): local directory where the dataset is stored
├ image_transforms (Callable): optional image transformations to apply to visual modalities
├ delta_timestamps (dict): optional delta timestamps for temporal queries
└ video_backend (str): video backend used for decoding videos (e.g., 'pyav', 'torchcodec')
```
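For instance, the episode boundaries stored in `meta.episodes` make it straightforward to slice out a single episode. A short sketch using the columns listed above:

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/aloha_static_coffee")

# Slice out the frames of episode 0 using the episode metadata above
episodes = dataset.meta.episodes
start = episodes["dataset_from_index"][0]
end = episodes["dataset_to_index"][0]
frames = [dataset[i] for i in range(start, end)]
print(len(frames), "frames in episode 0")
```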

A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:

@@ -283,7 +293,7 @@ lerobot-eval \

    --eval.n_episodes=10 \
    --policy.use_amp=false \
    --policy.device=cuda

Note: After training your own policy, you can re-evaluate the checkpoints with:
@@ -108,7 +108,8 @@ def save_decoded_frames(

def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
    ep_num_images = dataset.episode_data_index["to"][0].item()
    episode_index = 0
    ep_num_images = dataset.meta.episodes["length"][episode_index]
    if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images:
        return

@@ -265,7 +266,8 @@ def benchmark_encoding_decoding(

        overwrite=True,
    )

    ep_num_images = dataset.episode_data_index["to"][0].item()
    episode_index = 0
    ep_num_images = dataset.meta.episodes["length"][episode_index]
    width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
    num_pixels = width * height
    video_size_bytes = video_path.stat().st_size
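Both hunks swap the same lookup: in v2.x the frame count of an episode came from the cumulative `episode_data_index`, while in v3.0 it is read from the per-episode metadata table. Side by side (a restatement of the change above, not new API):

```python
# v2.x: episode length via cumulative frame indices
ep_num_images = dataset.episode_data_index["to"][0].item()

# v3.0: episode length via the per-episode metadata table
episode_index = 0
ep_num_images = dataset.meta.episodes["length"][episode_index]
```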
@@ -20,6 +20,12 @@

      - local: async
        title: Use Async Inference
    title: "Tutorials"
  - sections:
      - local: lerobot-dataset-v3
        title: Using LeRobotDataset
      - local: porting_datasets_v3
        title: Porting Large Datasets
    title: "Datasets"
  - sections:
      - local: smolvla
        title: Finetune SmolVLA

@@ -35,6 +41,8 @@

        title: Koch v1.1
      - local: lekiwi
        title: LeKiwi
      - local: reachy2
        title: Reachy 2
    title: "Robots"
  - sections:
      - local: notebooks
+56 -382
@@ -4,13 +4,7 @@ In this tutorial you will go through the full Human-in-the-Loop Sample-Efficient

HIL-SERL is a sample-efficient reinforcement learning algorithm that combines human demonstrations with online learning and human interventions. The approach starts from a small set of human demonstrations, uses them to train a reward classifier, and then employs an actor-learner architecture where humans can intervene during policy execution to guide exploration and correct unsafe behaviors. In this tutorial, you'll use a gamepad to provide interventions and control the robot during the learning process.

It combines three key ingredients:

1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point.

2. **On-robot actor/learner loop with human interventions:** a distributed Soft Actor-Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour.

3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing, and WandB monitoring keep the data useful and the hardware safe.

Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines.

@@ -62,243 +56,30 @@ pip install -e ".[hilserl]"

### Understanding Configuration

The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/scripts/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:

<!-- prettier-ignore-start -->
```python
class GymManipulatorConfig:
    env: HILSerlRobotEnvConfig  # Environment configuration (nested)
    dataset: DatasetConfig  # Dataset recording/replay configuration (nested)
    mode: str | None = None  # "record", "replay", or None (for training)
    device: str = "cpu"  # Compute device

class HILSerlRobotEnvConfig(EnvConfig):
    robot: RobotConfig | None = None  # Main robot agent (defined in `lerobot/robots`)
    teleop: TeleoperatorConfig | None = None  # Teleoperator agent, e.g., gamepad or leader arm
    processor: HILSerlProcessorConfig  # Processing pipeline configuration (nested)
    name: str = "real_robot"  # Environment name
    task: str | None = None  # Task identifier
    fps: int = 10  # Control frequency

# Nested processor configuration
class HILSerlProcessorConfig:
    control_mode: str = "gamepad"  # Control mode
    observation: ObservationConfig | None = None  # Observation processing settings
    image_preprocessing: ImagePreprocessingConfig | None = None  # Image crop/resize settings
    gripper: GripperConfig | None = None  # Gripper control and penalty settings
    reset: ResetConfig | None = None  # Environment reset and timing settings
    inverse_kinematics: InverseKinematicsConfig | None = None  # IK processing settings
    reward_classifier: RewardClassifierConfig | None = None  # Reward classifier settings
    max_gripper_pos: float | None = 100.0  # Maximum gripper position

# Sub-configuration classes
class ObservationConfig:
    add_joint_velocity_to_observation: bool = False  # Add joint velocities to state
    add_current_to_observation: bool = False  # Add motor currents to state
    add_ee_pose_to_observation: bool = False  # Add end-effector pose to state
    display_cameras: bool = False  # Display camera feeds during execution

class ImagePreprocessingConfig:
    crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None  # Image cropping parameters
    resize_size: tuple[int, int] | None = None  # Target image size

class GripperConfig:
    use_gripper: bool = True  # Enable gripper control
    gripper_penalty: float = 0.0  # Penalty for inappropriate gripper usage
    gripper_penalty_in_reward: bool = False  # Include gripper penalty in reward

class ResetConfig:
    fixed_reset_joint_positions: Any | None = None  # Joint positions for reset
    reset_time_s: float = 5.0  # Time to wait during reset
    control_time_s: float = 20.0  # Maximum episode duration
    terminate_on_success: bool = True  # Whether to terminate episodes on success detection

class InverseKinematicsConfig:
    urdf_path: str | None = None  # Path to robot URDF file
    target_frame_name: str | None = None  # End-effector frame name
    end_effector_bounds: dict[str, list[float]] | None = None  # EE workspace bounds
    end_effector_step_sizes: dict[str, float] | None = None  # EE step sizes per axis

class RewardClassifierConfig:
    pretrained_path: str | None = None  # Path to pretrained reward classifier
    success_threshold: float = 0.5  # Success detection threshold
    success_reward: float = 1.0  # Reward value for successful episodes

# Dataset configuration
class DatasetConfig:
    repo_id: str  # LeRobot dataset repository ID
    task: str  # Task identifier
    root: str | None = None  # Local dataset root directory
    num_episodes_to_record: int = 5  # Number of episodes for recording
    replay_episode: int | None = None  # Episode index for replay
    push_to_hub: bool = False  # Whether to push datasets to Hub
```
<!-- prettier-ignore-end -->
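Put together, a recording run can be configured in Python like this (a minimal sketch based on the dataclasses above; the robot and teleop sub-configs are omitted for brevity):

```python
# Assumes the dataclasses listed above are importable from the module noted there.
config = GymManipulatorConfig(
    env=HILSerlRobotEnvConfig(
        processor=HILSerlProcessorConfig(
            control_mode="gamepad",
            gripper=GripperConfig(use_gripper=True),
            reset=ResetConfig(reset_time_s=5.0, control_time_s=20.0),
        ),
    ),
    dataset=DatasetConfig(repo_id="username/pick_lift_cube", task="pick_and_lift"),
    mode="record",  # "record", "replay", or None (for training)
)
```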

### Processor Pipeline Architecture

HIL-SERL uses a modular processor pipeline architecture that processes robot observations and actions through a series of composable steps. The pipeline is divided into two main components (a schematic sketch of the composition follows the two lists below):

#### Environment Processor Pipeline

The environment processor (`env_processor`) handles incoming observations and environment state:

1. **VanillaObservationProcessorStep**: Converts raw robot observations into standardized format
2. **JointVelocityProcessorStep** (optional): Adds joint velocity information to observations
3. **MotorCurrentProcessorStep** (optional): Adds motor current readings to observations
4. **ForwardKinematicsJointsToEE** (optional): Computes end-effector pose from joint positions
5. **ImageCropResizeProcessorStep** (optional): Crops and resizes camera images
6. **TimeLimitProcessorStep** (optional): Enforces episode time limits
7. **GripperPenaltyProcessorStep** (optional): Applies penalties for inappropriate gripper usage
8. **RewardClassifierProcessorStep** (optional): Automated reward detection using vision models
9. **AddBatchDimensionProcessorStep**: Converts data to batch format for neural network processing
10. **DeviceProcessorStep**: Moves data to the specified compute device (CPU/GPU)

#### Action Processor Pipeline

The action processor (`action_processor`) handles outgoing actions and human interventions:

1. **AddTeleopActionAsComplimentaryDataStep**: Captures teleoperator actions for logging
2. **AddTeleopEventsAsInfoStep**: Records intervention events and episode control signals
3. **AddRobotObservationAsComplimentaryData**: Stores raw robot state for processing
4. **InterventionActionProcessorStep**: Handles human interventions and episode termination
5. **Inverse Kinematics Pipeline** (when enabled):
   - **MapDeltaActionToRobotActionStep**: Converts delta actions to robot action format
   - **EEReferenceAndDelta**: Computes end-effector reference and delta movements
   - **EEBoundsAndSafety**: Enforces workspace safety bounds
   - **InverseKinematicsEEToJoints**: Converts end-effector actions to joint targets
   - **GripperVelocityToJoint**: Handles gripper control commands
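Conceptually, each pipeline is ordered function composition over a transition. The sketch below illustrates the idea only; the `Pipeline` class and step function here are simplified stand-ins, not LeRobot's actual processor API:

```python
from typing import Any, Callable

Transition = dict[str, Any]

class Pipeline:
    """Minimal stand-in for a processor pipeline: apply steps in order."""

    def __init__(self, steps: list[Callable[[Transition], Transition]]):
        self.steps = steps

    def __call__(self, transition: Transition) -> Transition:
        for step in self.steps:
            transition = step(transition)
        return transition

def add_batch_dimension(transition: Transition) -> Transition:
    # Analogous to AddBatchDimensionProcessorStep: wrap values in a batch axis.
    transition["observation.state"] = [transition["observation.state"]]
    return transition

env_processor = Pipeline(steps=[add_batch_dimension])
print(env_processor({"observation.state": [0.1, 0.2, 0.3]}))
# {'observation.state': [[0.1, 0.2, 0.3]]}
```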

#### Configuration Examples

**Basic Observation Processing**:

```json
{
  "env": {
    "processor": {
      "observation": {
        "add_joint_velocity_to_observation": true,
        "add_current_to_observation": false,
        "display_cameras": false
      }
    }
  }
}
```

**Image Processing**:

```json
{
  "env": {
    "processor": {
      "image_preprocessing": {
        "crop_params_dict": {
          "observation.images.front": [180, 250, 120, 150],
          "observation.images.side": [180, 207, 180, 200]
        },
        "resize_size": [128, 128]
      }
    }
  }
}
```

**Inverse Kinematics Setup**:

```json
{
  "env": {
    "processor": {
      "inverse_kinematics": {
        "urdf_path": "path/to/robot.urdf",
        "target_frame_name": "end_effector",
        "end_effector_bounds": {
          "min": [0.16, -0.08, 0.03],
          "max": [0.24, 0.2, 0.1]
        },
        "end_effector_step_sizes": {
          "x": 0.02,
          "y": 0.02,
          "z": 0.02
        }
      }
    }
  }
}
```

### Advanced Observation Processing

The HIL-SERL framework supports additional observation processing features that can improve policy learning:

#### Joint Velocity Processing

Enable joint velocity estimation to provide the policy with motion information:

```json
{
  "env": {
    "processor": {
      "observation": {
        "add_joint_velocity_to_observation": true
      }
    }
  }
}
```

This processor:

- Estimates joint velocities using finite differences between consecutive joint position readings
- Adds velocity information to the observation state vector
- Useful for policies that need motion awareness for dynamic tasks
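A back-of-the-envelope sketch of the finite-difference estimate (illustration only, not the library's implementation; `dt` is the control period, i.e. `1 / fps`):

```python
import numpy as np

def estimate_joint_velocity(prev_pos: np.ndarray, curr_pos: np.ndarray, dt: float) -> np.ndarray:
    """Finite-difference velocity between two consecutive joint position readings."""
    return (curr_pos - prev_pos) / dt

prev_pos = np.array([0.00, 0.10, -0.20])
curr_pos = np.array([0.01, 0.12, -0.22])
velocity = estimate_joint_velocity(prev_pos, curr_pos, dt=1 / 10)  # fps = 10
# The resulting velocity vector is appended to the observation state vector.
```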

#### Motor Current Processing

Monitor motor currents to detect contact forces and load conditions:

```json
{
  "env": {
    "processor": {
      "observation": {
        "add_current_to_observation": true
      }
    }
  }
}
```

This processor:

- Reads motor current values from the robot's control system
- Adds current measurements to the observation state vector
- Helps detect contact events, object weights, and mechanical resistance
- Useful for contact-rich manipulation tasks

#### Combined Observation Processing

You can enable multiple observation processing features simultaneously:

```json
{
  "env": {
    "processor": {
      "observation": {
        "add_joint_velocity_to_observation": true,
        "add_current_to_observation": true,
        "add_ee_pose_to_observation": false,
        "display_cameras": false
      }
    }
  }
}
```

**Note**: Enabling additional observation features increases the state space dimensionality, which may require adjusting your policy network architecture and potentially collecting more training data.

### Finding Robot Workspace Bounds

Before collecting demonstrations, you need to determine the appropriate operational bounds for your robot.

@@ -349,56 +130,22 @@ With the bounds defined, you can safely collect demonstrations for training. Tra

Create a configuration file for recording demonstrations (or edit an existing one like [env_config_so100.json](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json)):

1. Set `mode` to `"record"` at the root level
2. Specify a unique `repo_id` for your dataset in the `dataset` section (e.g., "username/task_name")
3. Set `num_episodes_to_record` in the `dataset` section to the number of demonstrations you want to collect
4. Set `env.processor.image_preprocessing.crop_params_dict` to `{}` initially (we'll determine crops later)
5. Configure `env.robot`, `env.teleop`, and other hardware settings in the `env` section

Example configuration section:

```json
{
  "env": {
    "type": "gym_manipulator",
    "name": "real_robot",
    "fps": 10,
    "processor": {
      "control_mode": "gamepad",
      "observation": {
        "display_cameras": false
      },
      "image_preprocessing": {
        "crop_params_dict": {},
        "resize_size": [128, 128]
      },
      "gripper": {
        "use_gripper": true,
        "gripper_penalty": 0.0
      },
      "reset": {
        "reset_time_s": 5.0,
        "control_time_s": 20.0
      }
    },
    "robot": {
      // ... robot configuration ...
    },
    "teleop": {
      // ... teleoperator configuration ...
    }
  },
  "dataset": {
    "repo_id": "username/pick_lift_cube",
    "root": null,
    "task": "pick_and_lift",
    "num_episodes_to_record": 15,
    "replay_episode": 0,
    "push_to_hub": true
  },
  "mode": "record",
  "device": "cpu"
}
```
### Using a Teleoperation Device

@@ -444,20 +191,10 @@ The gamepad provides a very convenient way to control the robot and the episode

To set up the gamepad, you need to set the `control_mode` to `"gamepad"` and define the `teleop` section in the configuration file:

```json
{
  "env": {
    "teleop": {
      "type": "gamepad",
      "use_gripper": true
    },
    "processor": {
      "control_mode": "gamepad",
      "gripper": {
        "use_gripper": true
      }
    }
  }
}
```

<p align="center">

@@ -479,21 +216,11 @@ The SO101 leader arm has reduced gears that allows it to move and track the foll

To set up the SO101 leader, you need to set the `control_mode` to `"leader"` and define the `teleop` section in the configuration file (check the port number for your setup):

```json
{
  "env": {
    "teleop": {
      "type": "so101_leader",
      "port": "/dev/tty.usbmodem585A0077921",
      "use_degrees": true
    },
    "processor": {
      "control_mode": "leader",
      "gripper": {
        "use_gripper": true
      }
    }
  }
}
```

In order to annotate the success/failure of the episode, **you will need** to use a keyboard, pressing `s` for success and `esc` for failure.
@@ -524,7 +251,7 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/e

During recording:

1. The robot will reset to the initial position defined in the configuration file under `env.processor.reset.fixed_reset_joint_positions`
2. Complete the task successfully
3. The episode ends with a reward of 1 when you press the "success" button
4. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0

@@ -583,19 +310,11 @@ observation.images.front: [180, 250, 120, 150]

Add these crop parameters to your training configuration:

```json
{
  "env": {
    "processor": {
      "image_preprocessing": {
        "crop_params_dict": {
          "observation.images.side": [180, 207, 180, 200],
          "observation.images.front": [180, 250, 120, 150]
        },
        "resize_size": [128, 128]
      }
    }
  }
}
```
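The four crop values follow the `(top, left, height, width)` convention used by torchvision-style cropping (stated here as an assumption; double-check against your config). A quick sketch of what the preprocessing does to one camera frame:

```python
import torch
import torchvision.transforms.functional as F

# Assuming (top, left, height, width) crop parameters, as in torchvision:
top, left, height, width = 180, 250, 120, 150

frame = torch.rand(3, 480, 640)  # dummy C x H x W camera frame
cropped = F.crop(frame, top, left, height, width)
resized = F.resize(cropped, [128, 128])
print(resized.shape)  # torch.Size([3, 128, 128])
```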

**Recommended image resolution**

@@ -624,52 +343,26 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/r

**Key Parameters for Data Collection**

- **mode**: set it to `"record"` to collect a dataset (at the root level)
- **dataset.repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
- **dataset.num_episodes_to_record**: Number of episodes to record
- **env.processor.reset.terminate_on_success**: Whether to automatically terminate episodes when success is detected (default: `true`)
- **env.fps**: Number of frames per second to record
- **dataset.push_to_hub**: Whether to push the dataset to the hub

The `env.processor.reset.terminate_on_success` parameter controls episode termination behavior. When set to `false`, episodes continue even after success is detected, allowing you to collect more positive examples with the reward=1 label. This is crucial for training reward classifiers: without it there won't be enough states labeled 1 in the dataset to train a good classifier. When set to `true` (the default), episodes terminate immediately upon success detection.

**Important**: For reward classifier training, set `terminate_on_success: false` to collect sufficient positive examples. For regular HIL-SERL training, keep it as `true` to enable automatic episode termination when the task is completed successfully.

Example configuration section for data collection:

```json
{
  "env": {
    "type": "gym_manipulator",
    "name": "real_robot",
    "fps": 10,
    "processor": {
      "reset": {
        "reset_time_s": 5.0,
        "control_time_s": 20.0,
        "terminate_on_success": false
      },
      "gripper": {
        "use_gripper": true
      }
    },
    "robot": {
      // ... robot configuration ...
    },
    "teleop": {
      // ... teleoperator configuration ...
    }
  },
  "dataset": {
    "repo_id": "hf_username/dataset_name",
    "root": "data/your_dataset",
    "task": "reward_classifier_task",
    "num_episodes_to_record": 20,
    "replay_episode": null,
    "push_to_hub": true
  },
  "mode": "record",
  "device": "cpu"
}
```
@@ -728,17 +421,9 @@ To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to

<!-- prettier-ignore-start -->
```python
config = GymManipulatorConfig(
    env=HILSerlRobotEnvConfig(
        processor=HILSerlProcessorConfig(
            reward_classifier=RewardClassifierConfig(
                pretrained_path="path_to_your_pretrained_model"
            )
        ),
        # Other environment parameters
    ),
    dataset=DatasetConfig(...),
    mode=None,  # For training
)
```
<!-- prettier-ignore-end -->

@@ -747,18 +432,7 @@ or set the argument in the json config file.

```json
{
  "env": {
    "processor": {
      "reward_classifier": {
        "pretrained_path": "path_to_your_pretrained_model",
        "success_threshold": 0.7,
        "success_reward": 1.0
      },
      "reset": {
        "terminate_on_success": true
      }
    }
  }
}
```
+30 -56
@@ -32,12 +32,9 @@ To use `gym_hil` with LeRobot, you need to create a configuration file. An examp

```json
{
  "env": {
    "type": "gym_manipulator",
    "name": "gym_hil",
    "task": "PandaPickCubeGamepad-v0",
    "fps": 10
  }
}
```

@@ -48,40 +45,28 @@ Available tasks:

- `PandaPickCubeGamepad-v0`: With gamepad control
- `PandaPickCubeKeyboard-v0`: With keyboard control

### Processor Configuration

```json
{
  "env": {
    "processor": {
      "control_mode": "gamepad",
      "gripper": {
        "use_gripper": true,
        "gripper_penalty": -0.02
      },
      "reset": {
        "control_time_s": 15.0,
        "fixed_reset_joint_positions": [
          0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785
        ]
      },
      "inverse_kinematics": {
        "end_effector_step_sizes": {
          "x": 0.025,
          "y": 0.025,
          "z": 0.025
        }
      }
    }
  }
}
```

Important parameters:

- `gripper.gripper_penalty`: Penalty for excessive gripper movement
- `gripper.use_gripper`: Whether to enable gripper control
- `inverse_kinematics.end_effector_step_sizes`: Size of the steps along the x, y, z axes of the end-effector
- `control_mode`: Set to `"gamepad"` to use a gamepad controller

## Running with HIL RL of LeRobot

@@ -90,50 +75,39 @@ Important parameters:

To run the environment, set `mode` to `null`:

```bash
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
```

### Recording a Dataset

To collect a dataset, set the mode to `record` whilst defining the `repo_id` and the number of episodes to record:

```json
{
  "env": {
    "type": "gym_manipulator",
    "name": "gym_hil",
    "task": "PandaPickCubeGamepad-v0"
  },
  "dataset": {
    "repo_id": "username/sim_dataset",
    "root": null,
    "task": "pick_cube",
    "num_episodes_to_record": 10,
    "replay_episode": null,
    "push_to_hub": true
  },
  "mode": "record"
}
```

```bash
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
```

### Training a Policy

To train a policy, check out the configuration example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_gym_hil_env.json) and run the actor and learner servers:

```bash
python -m lerobot.scripts.rl.actor --config_path path/to/train_gym_hil_env.json
```

In a different terminal, run the learner server:

```bash
python -m lerobot.scripts.rl.learner --config_path path/to/train_gym_hil_env.json
```

The simulation environment provides a safe and repeatable way to develop and test your Human-in-the-Loop reinforcement learning components before deploying to real robots.
@@ -519,14 +519,11 @@ from lerobot.utils.control_utils import init_keyboard_listener

from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.record import record_loop
from lerobot.policies.factory import make_processor

NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"

# Create the robot configuration
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}

@@ -538,7 +535,7 @@ robot_config = SO100FollowerConfig(

robot = SO100Follower(robot_config)

# Initialize the policy
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)

# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")

@@ -547,7 +544,7 @@ dataset_features = {**action_features, **obs_features}

# Create the dataset
dataset = LeRobotDataset.create(
    repo_id=HF_DATASET_ID,
    fps=FPS,
    features=dataset_features,
    robot_type=robot.name,

@@ -562,12 +559,6 @@ _init_rerun(session_name="recording")

# Connect the robot
robot.connect()

preprocessor, postprocessor = make_processor(
    policy_cfg=policy,
    pretrained_path=HF_MODEL_ID,
    dataset_stats=dataset.meta.stats,
)

for episode_idx in range(NUM_EPISODES):
    log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")

@@ -577,8 +568,6 @@ for episode_idx in range(NUM_EPISODES):

        events=events,
        fps=FPS,
        policy=policy,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        dataset=dataset,
        control_time_s=EPISODE_TIME_SEC,
        single_task=TASK_DESCRIPTION,
+5 -53
@@ -24,36 +24,11 @@ pip install -e ".[hilserl]"

To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_gym_hil_il.json).

To teleoperate and collect a dataset, we need to modify this config file. Here's an example configuration for imitation learning data collection:

```json
{
  "env": {
    "type": "gym_manipulator",
    "name": "gym_hil",
    "task": "PandaPickCubeGamepad-v0",
    "fps": 10
  },
  "dataset": {
    "repo_id": "your_username/il_gym",
    "root": null,
    "task": "pick_cube",
    "num_episodes_to_record": 30,
    "replay_episode": null,
    "push_to_hub": true
  },
  "mode": "record",
  "device": "cuda"
}
```

Key configuration points:

- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"`
- Set `num_episodes_to_record: 30` to collect 30 demonstration episodes
- Ensure `mode` is set to `"record"`
- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` (macOS) or `"cpu"`
- To use a keyboard instead of a gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"`

Then we can run this command to start:

@@ -165,32 +140,9 @@ huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \

## Evaluate your policy in Sim

To evaluate your policy, we have to use a configuration file. An example can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json).

Here's an example evaluation configuration:

```json
{
  "env": {
    "type": "gym_manipulator",
    "name": "gym_hil",
    "task": "PandaPickCubeGamepad-v0",
    "fps": 10
  },
  "dataset": {
    "repo_id": "your_username/il_sim_dataset",
    "root": null,
    "task": "pick_cube"
  },
  "pretrained_policy_name_or_path": "your_username/il_sim_model",
  "device": "cuda"
}
```

Make sure to replace:

- `repo_id` with the dataset you trained on (e.g., `your_username/il_sim_dataset`)
- `pretrained_policy_name_or_path` with your model ID (e.g., `your_username/il_sim_model`)

Then you can run this command to visualize your trained policy:
@@ -0,0 +1,169 @@

# LeRobotDataset v3.0

`LeRobotDataset v3.0` is a standardized format for robot learning data. It provides unified access to multi-modal time-series data, sensorimotor signals and multi‑camera video, as well as rich metadata for indexing, search, and visualization on the Hugging Face Hub.

This guide will help you:

- Understand the v3.0 design and directory layout
- Record a dataset and push it to the Hub
- Load datasets for training with `LeRobotDataset`
- Stream datasets without downloading using `StreamingLeRobotDataset`
- Migrate existing `v2.1` datasets to `v3.0`

## What’s new in `v3`

- **File-based storage**: Many episodes per Parquet/MP4 file (v2 used one file per episode).
- **Relational metadata**: Episode boundaries and lookups are resolved through metadata, not filenames.
- **Hub-native streaming**: Consume datasets directly from the Hub with `StreamingLeRobotDataset`.
- **Lower file-system pressure**: Fewer, larger files ⇒ faster initialization and fewer issues at scale.
- **Unified organization**: Clean directory layout with consistent path templates across data and videos.

## Installation

`LeRobotDataset v3.0` will be included in `lerobot >= 0.4.0`.

Until that stable release, you can use the main branch by following the [build from source instructions](./installation#from-source).

## Record a dataset

Run the command below to record a dataset with the SO-101 and push it to the Hub:

```bash
lerobot-record \
    --robot.type=so101_follower \
    --robot.port=/dev/tty.usbmodem585A0076841 \
    --robot.id=my_awesome_follower_arm \
    --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
    --teleop.type=so101_leader \
    --teleop.port=/dev/tty.usbmodem58760431551 \
    --teleop.id=my_awesome_leader_arm \
    --display_data=true \
    --dataset.repo_id=${HF_USER}/record-test \
    --dataset.num_episodes=5 \
    --dataset.single_task="Grab the black cube"
```

See the [recording guide](./il_robots#record-a-dataset) for more details.

## Format design

A core v3 principle is **decoupling storage from the user API**: data is stored efficiently (few large files), while the public API exposes intuitive episode-level access.

`v3` has three pillars:

1. **Tabular data**: Low‑dimensional, high‑frequency signals (states, actions, timestamps) stored in **Apache Parquet**. Access is memory‑mapped or streamed via the `datasets` stack.
2. **Visual data**: Camera frames concatenated and encoded into **MP4**. Frames from the same episode are grouped; videos are sharded per camera for practical sizes.
3. **Metadata**: JSON/Parquet records describing the schema (feature names, dtypes, shapes), frame rates, normalization stats, and **episode segmentation** (start/end offsets into shared Parquet/MP4 files).

> To scale to millions of episodes, tabular rows and video frames from multiple episodes are **concatenated** into larger files. Episode‑specific views are reconstructed **via metadata**, not file boundaries.
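As a toy illustration of that decoupling, episode lookup reduces to an offset table kept in metadata (the field names below are simplified stand-ins for the real `meta/episodes/` schema):

```python
# Simplified stand-in for episode segmentation metadata:
# many episodes share one parquet file; metadata stores their row spans.
episodes_meta = {
    0: {"data_file": "data/chunk-000/file-000.parquet", "from": 0, "to": 1500},
    1: {"data_file": "data/chunk-000/file-000.parquet", "from": 1500, "to": 2800},
}

def episode_rows(episode_index: int) -> tuple[str, range]:
    """Resolve an episode to (shared file, row range) via metadata."""
    meta = episodes_meta[episode_index]
    return meta["data_file"], range(meta["from"], meta["to"])

path, rows = episode_rows(1)  # rows 1500..2799 of the shared shard
```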

<div style="display:flex; justify-content:center; gap:12px; flex-wrap:wrap;">
  <figure style="margin:0; text-align:center;">
    <img
      src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobotdataset-v3/asset1datasetv3.png"
      alt="LeRobotDataset v3 diagram"
      width="220"
    />
    <figcaption style="font-size:0.9em; color:#666;">
      From episode‑based to file‑based datasets
    </figcaption>
  </figure>
</div>

### Directory layout (simplified)

- **`meta/info.json`**: canonical schema (features, shapes/dtypes), FPS, codebase version, and **path templates** to locate data/video shards.
- **`meta/stats.json`**: global feature statistics (mean/std/min/max) used for normalization; exposed as `dataset.meta.stats`.
- **`meta/tasks.jsonl`**: natural‑language task descriptions mapped to integer IDs for task‑conditioned policies.
- **`meta/episodes/`**: per‑episode records (lengths, tasks, offsets) stored as **chunked Parquet** for scalability.
- **`data/`**: frame‑by‑frame **Parquet** shards; each file typically contains **many episodes**.
- **`videos/`**: **MP4** shards per camera; each file typically contains **many episodes**.
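For instance, the path templates in `meta/info.json` resolve chunk and file indices into concrete shard paths. A sketch (the template strings here are illustrative, not the exact schema):

```python
# Illustrative path templates, resolved with chunk/file indices
data_path = "data/chunk-{chunk_index:03d}/file-{file_index:03d}.parquet"
video_path = "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4"

print(data_path.format(chunk_index=0, file_index=0))
# data/chunk-000/file-000.parquet
print(video_path.format(video_key="observation.images.front_left", chunk_index=0, file_index=0))
# videos/observation.images.front_left/chunk-000/file-000.mp4
```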

## Load a dataset for training

`LeRobotDataset` returns Python dictionaries of PyTorch tensors and integrates with `torch.utils.data.DataLoader`. Here is a code example showing its use:

```python
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset

repo_id = "yaak-ai/L2D-v3"

# 1) Load from the Hub (cached locally)
dataset = LeRobotDataset(repo_id)

# 2) Random access by index
sample = dataset[100]
print(sample)
# {
#     'observation.state': tensor([...]),
#     'action': tensor([...]),
#     'observation.images.front_left': tensor([C, H, W]),
#     'timestamp': tensor(1.234),
#     ...
# }

# 3) Temporal windows via delta_timestamps (seconds relative to t)
delta_timestamps = {
    "observation.images.front_left": [-0.2, -0.1, 0.0]  # 0.2s before, 0.1s before, and the current frame
}

dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)

# Accessing an index now returns a stack for the specified key(s)
sample = dataset[100]
print(sample["observation.images.front_left"].shape)  # [T, C, H, W], where T=3

# 4) Wrap with a DataLoader for training
batch_size = 16
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

device = "cuda" if torch.cuda.is_available() else "cpu"
for batch in data_loader:
    observations = batch["observation.state"].to(device)
    actions = batch["action"].to(device)
    images = batch["observation.images.front_left"].to(device)
    # model.forward(batch)
```

## Stream a dataset (no downloads)

Use `StreamingLeRobotDataset` to iterate directly from the Hub without local copies. This lets you stream large datasets without downloading them to disk or loading them fully into memory, and is a key feature of the new dataset format.

```python
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset

repo_id = "yaak-ai/L2D-v3"
dataset = StreamingLeRobotDataset(repo_id)  # streams directly from the Hub
```
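Like `LeRobotDataset`, the streaming variant yields frame dictionaries. Since random indexing would defeat streaming, it is consumed as an iterable (a minimal sketch under that assumption):

```python
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset

dataset = StreamingLeRobotDataset("yaak-ai/L2D-v3")

for i, frame in enumerate(dataset):  # frames are fetched on the fly
    print(frame["observation.state"].shape)
    if i == 4:  # peek at the first few frames only
        break
```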

<div style="display:flex; justify-content:center; gap:12px; flex-wrap:wrap;">
  <figure style="margin:0; text-align:center;">
    <img
      src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobotdataset-v3/streaming-lerobot.png"
      alt="StreamingLeRobotDataset"
      width="520"
    />
    <figcaption style="font-size:0.9em; color:#666;">
      Stream directly from the Hub for on‑the‑fly training.
    </figcaption>
  </figure>
</div>

## Migrate `v2.1` → `v3.0`

A converter aggregates per‑episode files into larger shards and writes episode offsets/metadata. Convert your dataset using the instructions below.

```bash
# Pre-release build with v3 support:
pip install "https://github.com/huggingface/lerobot/archive/33cad37054c2b594ceba57463e8f11ee374fa93c.zip"

# Convert an existing v2.1 dataset hosted on the Hub:
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
```

**What it does**

- Aggregates parquet files: `episode-0000.parquet`, `episode-0001.parquet`, … → **`file-0000.parquet`**, …
- Aggregates mp4 files: `episode-0000.mp4`, `episode-0001.mp4`, … → **`file-0000.mp4`**, …
- Updates `meta/episodes/*` (chunked Parquet) with per‑episode lengths, tasks, and byte/frame offsets.
@@ -0,0 +1,321 @@

# Porting Large Datasets to LeRobot Dataset v3.0

This tutorial explains how to port large-scale robotic datasets to the LeRobot Dataset v3.0 format. We'll use the **DROID 1.0.1** dataset as our primary example, which demonstrates handling multi-terabyte datasets with thousands of shards across SLURM clusters.

## File Organization: v2.1 vs v3.0

Dataset v3.0 fundamentally changes how data is organized and stored:

**v2.1 Structure (Episode-based)**:

```
dataset/
├── data/chunk-000/episode_000000.parquet
├── data/chunk-000/episode_000001.parquet
├── videos/chunk-000/camera/episode_000000.mp4
└── meta/episodes.jsonl
```

**v3.0 Structure (File-based)**:

```
dataset/
├── data/chunk-000/file-000.parquet           # Multiple episodes per file
├── videos/camera/chunk-000/file-000.mp4      # Consolidated video chunks
└── meta/episodes/chunk-000/file-000.parquet  # Structured metadata
```

This transition from individual episode files to file-based chunks dramatically improves performance and reduces storage overhead.

## What's New in Dataset v3.0

Dataset v3.0 introduces significant improvements for handling large datasets:

### 🏗️ **Enhanced File Organization**

- **File-based structure**: Episodes are now grouped into chunked files rather than individual episode files
- **Configurable file sizes** for data and video files
- **Improved storage efficiency**: Better compression and reduced overhead

### 📊 **Modern Metadata Management**

- **Parquet-based metadata**: Replaced JSON Lines with the efficient parquet format
- **Structured episode access**: Direct pandas DataFrame access via `dataset.meta.episodes`
- **Per-episode statistics**: Enhanced statistics tracking at the episode level

### 🚀 **Performance Enhancements**

- **Memory-mapped access**: Improved RAM usage through PyArrow memory mapping
- **Faster loading**: Significantly reduced dataset initialization time
- **Better scalability**: Designed for datasets with millions of episodes

## Prerequisites

Before porting large datasets, ensure you have:

- **LeRobot installed** with v3.0 support. Follow our [Installation Guide](./installation).
- **Sufficient storage**: Raw datasets can be very large (e.g., DROID requires 2TB)
- **Cluster access** (recommended for large datasets): SLURM or a similar job scheduler
- **Dataset-specific dependencies**: For DROID, you'll need TensorFlow Dataset utilities

## Understanding the DROID Dataset

[DROID 1.0.1](https://droid-dataset.github.io/droid/the-droid-dataset) is an excellent example of a large-scale robotic dataset:

- **Size**: 1.7TB (RLDS format), 8.7TB (raw data)
- **Structure**: 2048 pre-defined TensorFlow dataset shards
- **Content**: 76,000+ robot manipulation trajectories from Franka Emika Panda robots
- **Scope**: Real-world manipulation tasks across multiple environments and objects
- **Format**: Originally in TensorFlow Records/RLDS format, requiring conversion to the LeRobot format
- **Hosting**: Google Cloud Storage with public access via `gsutil`

The dataset contains diverse manipulation demonstrations with:

- Multiple camera views (wrist camera, exterior cameras)
- Natural language task descriptions
- Robot proprioceptive state and actions
- Success/failure annotations

### DROID Features Schema

```python
DROID_FEATURES = {
    # Episode markers
    "is_first": {"dtype": "bool", "shape": (1,)},
    "is_last": {"dtype": "bool", "shape": (1,)},
    "is_terminal": {"dtype": "bool", "shape": (1,)},

    # Language instructions
    "language_instruction": {"dtype": "string", "shape": (1,)},
    "language_instruction_2": {"dtype": "string", "shape": (1,)},
    "language_instruction_3": {"dtype": "string", "shape": (1,)},

    # Robot state
    "observation.state.gripper_position": {"dtype": "float32", "shape": (1,)},
    "observation.state.cartesian_position": {"dtype": "float32", "shape": (6,)},
    "observation.state.joint_position": {"dtype": "float32", "shape": (7,)},

    # Camera observations
    "observation.images.wrist_left": {"dtype": "image"},
    "observation.images.exterior_1_left": {"dtype": "image"},
    "observation.images.exterior_2_left": {"dtype": "image"},

    # Actions
    "action.gripper_position": {"dtype": "float32", "shape": (1,)},
    "action.cartesian_position": {"dtype": "float32", "shape": (6,)},
    "action.joint_position": {"dtype": "float32", "shape": (7,)},

    # Standard LeRobot format
    "observation.state": {"dtype": "float32", "shape": (8,)},  # joints + gripper
    "action": {"dtype": "float32", "shape": (8,)},  # joints + gripper
}
```
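This schema is what a porting script feeds to the dataset constructor. As a sketch of the connection, using `LeRobotDataset.create` as shown in the dataset docs (the `fps` and `robot_type` values here are illustrative assumptions; check them against the source data before porting):

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset

# Illustrative values: DROID is commonly cited at 15 Hz on Franka arms,
# but verify against the raw dataset before a real port.
dataset = LeRobotDataset.create(
    repo_id="your_id/droid_1.0.1",
    fps=15,
    features=DROID_FEATURES,
    robot_type="franka",
)
# Episodes are then appended frame by frame from the RLDS shards.
```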
|
||||
|
||||
## Approach 1: Single Computer Porting
|
||||
|
||||
### Step 1: Install Dependencies
|
||||
|
||||
For DROID specifically:
|
||||
|
||||
```bash
|
||||
pip install tensorflow
|
||||
pip install tensorflow_datasets
|
||||
```
|
||||
|
||||
For other datasets, install the appropriate readers for your source format.

### Step 2: Download Raw Data

Download DROID from Google Cloud Storage using `gsutil`:

```bash
# Install Google Cloud SDK if not already installed
# https://cloud.google.com/sdk/docs/install

# Download the full RLDS dataset (1.7TB)
gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 /your/data/

# Or download just the 100-episode sample (2GB) for testing
gsutil -m cp -r gs://gresearch/robotics/droid_100 /your/data/
```

> [!WARNING]
> Large datasets require substantial time and storage:
>
> - **Full DROID (1.7TB)**: Several days to download depending on bandwidth
> - **Processing time**: 7+ days for local porting of the full dataset
> - **Upload time**: 3+ days to push to the Hugging Face Hub
> - **Local storage**: ~400GB for the processed LeRobot format

### Step 3: Port the Dataset

```bash
python examples/port_datasets/port_droid.py \
    --raw-dir /your/data/droid/1.0.1 \
    --repo-id your_id/droid_1.0.1 \
    --push-to-hub
```

### Development and Testing

For development, you can port a single shard:

```bash
python examples/port_datasets/port_droid.py \
    --raw-dir /your/data/droid/1.0.1 \
    --repo-id your_id/droid_1.0.1_test \
    --num-shards 2048 \
    --shard-index 0
```

This approach works for smaller datasets or testing, but large datasets require cluster computing.

## Approach 2: SLURM Cluster Porting (Recommended)

For large datasets like DROID, parallel processing across multiple nodes dramatically reduces processing time.

### Step 1: Install Cluster Dependencies

```bash
pip install datatrove  # Hugging Face's distributed processing library
```

### Step 2: Configure Your SLURM Environment

Find your partition information:

```bash
sinfo --format="%R"  # List available partitions
sinfo -N -p your_partition -h -o "%N cpus=%c mem=%m"  # Check resources
```

Choose a **CPU partition** - no GPU is needed for dataset porting.

### Step 3: Launch Parallel Porting Jobs

```bash
python examples/port_datasets/slurm_port_shards.py \
    --raw-dir /your/data/droid/1.0.1 \
    --repo-id your_id/droid_1.0.1 \
    --logs-dir /your/logs \
    --job-name port_droid \
    --partition your_partition \
    --workers 2048 \
    --cpus-per-task 8 \
    --mem-per-cpu 1950M
```

#### Parameter Guidelines

- **`--workers`**: Number of parallel jobs (max 2048 for DROID's shard count)
- **`--cpus-per-task`**: 8 CPUs recommended for frame encoding parallelization
- **`--mem-per-cpu`**: ~16GB total RAM (8×1950M) for loading raw frames

> [!TIP]
> Start with fewer workers (e.g., 100) to test your cluster configuration before launching thousands of jobs.

### Step 4: Monitor Progress

Check running jobs:

```bash
squeue -u $USER
```

Monitor overall progress:

```bash
jobs_status /your/logs
```

Inspect individual job logs:

```bash
less /your/logs/port_droid/slurm_jobs/JOB_ID_WORKER_ID.out
```

Debug failed jobs:

```bash
failed_logs /your/logs/port_droid
```

### Step 5: Aggregate Shards

Once all porting jobs complete:

```bash
python examples/port_datasets/slurm_aggregate_shards.py \
    --repo-id your_id/droid_1.0.1 \
    --logs-dir /your/logs \
    --job-name aggr_droid \
    --partition your_partition \
    --workers 2048 \
    --cpus-per-task 8 \
    --mem-per-cpu 1950M
```

### Step 6: Upload to Hub

```bash
python examples/port_datasets/slurm_upload.py \
    --repo-id your_id/droid_1.0.1 \
    --logs-dir /your/logs \
    --job-name upload_droid \
    --partition your_partition \
    --workers 50 \
    --cpus-per-task 4 \
    --mem-per-cpu 1950M
```

> [!NOTE]
> Upload uses fewer workers (50) since it's network-bound rather than compute-bound.

## Dataset v3.0 File Structure

Your completed dataset will have this modern structure:

```
dataset/
├── meta/
│   ├── episodes/
│   │   └── chunk-000/
│   │       └── file-000.parquet  # Episode metadata
│   ├── tasks.parquet  # Task definitions
│   ├── stats.json  # Aggregated statistics
│   └── info.json  # Dataset information
├── data/
│   └── chunk-000/
│       └── file-000.parquet  # Consolidated episode data
└── videos/
    └── camera_key/
        └── chunk-000/
            └── file-000.mp4  # Consolidated video files
```

This replaces the old episode-per-file structure with efficient, optimally-sized chunks.
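
As a quick sanity check after porting, you can load the result back. A minimal sketch, assuming the `your_id/droid_1.0.1` repo id used above; `fps`, `num_frames`, and `meta.info` are the accessors used elsewhere in these examples:

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset

# Loads from the Hub (or the local cache) and prints basic metadata.
dataset = LeRobotDataset("your_id/droid_1.0.1")
print(dataset.fps)  # 15 for DROID
print(dataset.num_frames)  # total number of frames across all episodes
print(dataset.meta.info["codebase_version"])  # dataset format version
```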

## Migrating from Dataset v2.1

If you have existing datasets in v2.1 format, use the migration tool:

```bash
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
    --repo-id your_id/existing_dataset
```

This automatically:

- Converts file structure to v3.0 format
- Migrates metadata from JSON Lines to parquet
- Aggregates statistics and creates per-episode stats
- Updates version information
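
After conversion, per-episode frame ranges are read from the parquet-backed episode metadata instead of the old `episode_data_index` tensors. A sketch, using the accessors from the updated example further below:

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("your_id/existing_dataset")

# v3.0: episode boundaries live in the episodes metadata table.
episode_index = 0
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
print(f"episode {episode_index} spans frames [{from_idx}, {to_idx})")
```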

## Performance Benefits

Dataset v3.0 provides significant improvements for large datasets:

- **Faster loading**: 3-5x reduction in initialization time
- **Memory efficiency**: Better RAM usage through memory mapping
- **Scalable processing**: Handles millions of episodes efficiently
- **Storage optimization**: Reduced file count and improved compression
@@ -0,0 +1,288 @@
# Reachy 2

Reachy 2 is an open-source humanoid robot made by Pollen Robotics, specifically designed for the development of embodied AI and real-world applications.
Check out the [Pollen Robotics website](https://www.pollen-robotics.com/reachy/), or access the [Reachy 2 documentation](https://docs.pollen-robotics.com/) for more information on the platform!

## Teleoperate Reachy 2

Currently, there are two ways to teleoperate Reachy 2:

- Pollen Robotics’ VR teleoperation (not included in LeRobot).
- Robot-to-robot teleoperation (use one Reachy 2 to control another), as sketched below.
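
For robot-to-robot teleoperation, a minimal command sketch, assuming `lerobot.teleoperate` accepts the same `--robot.*`/`--teleop.*` options as the `lerobot.record` examples below, with the follower at `192.168.0.200` and the compliant leader at `192.168.0.201`:

```bash
python -m lerobot.teleoperate \
    --robot.type=reachy2 \
    --robot.ip_address=192.168.0.200 \
    --teleop.type=reachy2_teleoperator \
    --teleop.ip_address=192.168.0.201 \
    --teleop.use_present_position=true \
    --display_data=true
```

`--teleop.use_present_position=true` forwards the compliant leader robot's measured joint positions, as explained in the options further below.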

## Reachy 2 Simulation

**(Linux only)** You can run Reachy 2 in simulation (Gazebo or MuJoCo) using the provided [Docker image](https://hub.docker.com/r/pollenrobotics/reachy2_core).

1. Install [Docker Engine](https://docs.docker.com/engine/).
2. Run (for MuJoCo):

```
docker run --rm -it \
  --name reachy \
  --privileged \
  --network host \
  --ipc host \
  --device-cgroup-rule='c 189:* rwm' \
  --group-add audio \
  -e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
  -e DISPLAY="$DISPLAY" \
  -e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
  -e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
  -v /dev:/dev \
  -v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
  -v "$HOME/.reachy.log":/home/reachy/.ros/log \
  -v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
  --entrypoint /package/launch.sh \
  pollenrobotics/reachy2_core:1.7.5.9_deploy \
  start_rviz:=true start_sdk_server:=true mujoco:=true
```

> If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance:
>
> ```
> docker run --rm -it \
>   --name reachy \
>   --privileged \
>   --network host \
>   --ipc host \
>   --device-cgroup-rule='c 189:* rwm' \
>   --group-add audio \
>   -e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
>   -e DISPLAY="$DISPLAY" \
>   -e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
>   -e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
>   -e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \
>   -v /dev:/dev \
>   -v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
>   -v "$HOME/.reachy.log":/home/reachy/.ros/log \
>   -v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
>   --entrypoint /package/launch.sh \
>   pollenrobotics/reachy2_core:1.7.5.9_deploy \
>   start_rviz:=true start_sdk_server:=true mujoco:=true
> ```

## Setup

### Prerequisites

- On your robot, check that the **service images** meet the minimum versions:
  - **reachy2-core >= 1.7.5.2**
  - **webrtc >= 2.0.1.1**

Then, if you want to use VR teleoperation:

- Install the [Reachy 2 teleoperation application](https://docs.pollen-robotics.com/teleoperation/teleoperation-introduction/discover-teleoperation/).
  Use version **>= v1.2.0**.

We recommend using two computers: one for teleoperation (Windows required) and another for recording with LeRobot.

### Install LeRobot

Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot.

Install LeRobot with the Reachy 2 dependencies:

```bash
pip install -e ".[reachy2]"
```

### (Optional but recommended) Install pollen_data_acquisition_server

How you manage Reachy 2 recording sessions is up to you, but the **easiest** way is to use this server so you can control sessions directly from the VR teleoperation app.

> **Note:** Currently, only the VR teleoperation application works as a client for this server, so this step primarily targets teleoperation. You’re free to develop custom clients to manage sessions to suit your needs.

In your LeRobot environment, install the server from source:

```bash
git clone https://github.com/pollen-robotics/pollen_data_acquisition_server.git
cd pollen_data_acquisition_server
pip install -e .
```

Find the [pollen_data_acquisition_server documentation here](https://github.com/pollen-robotics/pollen_data_acquisition_server).

## Step 1: Recording

### Get Reachy 2 IP address

Before starting teleoperation and data recording, find the [robot's IP address](https://docs.pollen-robotics.com/getting-started/setup-reachy2/connect-reachy2/).
We strongly recommend connecting all devices (PC and robot) via **Ethernet**.

### Launch recording

There are two ways to manage recording sessions when using the Reachy 2 VR teleoperation application:

- **Using the data acquisition server (recommended for VR teleop)**: The VR app orchestrates sessions (via the server, it tells LeRobot when to create datasets and when to start/stop episodes) while also controlling the robot’s motions.
- **Using LeRobot’s record script**: LeRobot owns session control and decides when to start/stop episodes. If you also use the VR teleop app, it’s only for motion control.

### Option 1: Using Pollen data acquisition server (recommended for VR teleop)

Make sure you have installed pollen_data_acquisition_server, as explained in the Setup section.

Launch the data acquisition server to be able to manage your session directly from the teleoperation application:

```bash
python -m pollen_data_acquisition_server.server
```

Then open the teleoperation application and choose "Data acquisition session".
You can then set up your session by following the screens displayed.

> Even without the VR app, you can use the `pollen_data_acquisition_server` with your own client implementation.

### Option 2: Using lerobot.record

Reachy 2 is fully supported by LeRobot’s recording features.
If you choose this option but still want to use the VR teleoperation application, select "Standard session" in the app.

**Example: start a recording without the mobile base:**
First add `reachy2` and `reachy2_teleoperator` to the imports of the record script. Then you can use the following command:

```bash
python -m lerobot.record \
    --robot.type=reachy2 \
    --robot.ip_address=192.168.0.200 \
    --robot.id=r2-0000 \
    --robot.use_external_commands=true \
    --robot.with_mobile_base=false \
    --teleop.type=reachy2_teleoperator \
    --teleop.ip_address=192.168.0.200 \
    --teleop.with_mobile_base=false \
    --dataset.repo_id=pollen_robotics/record_test \
    --dataset.single_task="Reachy 2 recording test" \
    --dataset.num_episodes=1 \
    --dataset.episode_time_s=5 \
    --dataset.fps=15 \
    --dataset.push_to_hub=true \
    --dataset.private=true \
    --display_data=true
```

#### Specific Options

**Extended setup overview (all options included):**

```bash
python -m lerobot.record \
    --robot.type=reachy2 \
    --robot.ip_address=192.168.0.200 \
    --robot.use_external_commands=true \
    --robot.with_mobile_base=true \
    --robot.with_l_arm=true \
    --robot.with_r_arm=true \
    --robot.with_neck=true \
    --robot.with_antennas=true \
    --robot.with_left_teleop_camera=true \
    --robot.with_right_teleop_camera=true \
    --robot.with_torso_camera=false \
    --robot.disable_torque_on_disconnect=false \
    --robot.max_relative_target=5.0 \
    --teleop.type=reachy2_teleoperator \
    --teleop.ip_address=192.168.0.200 \
    --teleop.use_present_position=false \
    --teleop.with_mobile_base=false \
    --teleop.with_l_arm=true \
    --teleop.with_r_arm=true \
    --teleop.with_neck=true \
    --teleop.with_antennas=true \
    --dataset.repo_id=pollen_robotics/record_test \
    --dataset.single_task="Reachy 2 recording test" \
    --dataset.num_episodes=1 \
    --dataset.episode_time_s=5 \
    --dataset.fps=15 \
    --dataset.push_to_hub=true \
    --dataset.private=true \
    --display_data=true
```

##### `--robot.use_external_commands`

Determines whether the robot is driven by commands coming from outside LeRobot, in which case `robot.send_action()` does not send commands to the robot.
**Must** be set to `true` while using the VR teleoperation application, as the app already sends commands (as in the record examples above).

##### `--teleop.use_present_position`

Determines whether the teleoperator reads the goal or the present position of the robot.
Must be set to `true` if a compliant Reachy 2 is used to control another one.

##### Use the relevant parts

From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies.
To avoid this, you can exclude specific parts from recording and replay using:

```
--robot.with_<part>=false
```

with `<part>` being one of: `mobile_base`, `l_arm`, `r_arm`, `neck`, `antennas`.
This determines whether the corresponding part is recorded in the observations. Defaults to true if not set.

By default, **all parts are recorded**.

The same per-part mechanism is available in `reachy2_teleoperator` as well:

```
--teleop.with_<part>=false
```

with `<part>` being one of: `mobile_base`, `l_arm`, `r_arm`, `neck`, `antennas`.
This determines whether the corresponding part is recorded in the actions. Defaults to true if not set.

> **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator, as in the example below.
> For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part with `--teleop.with_mobile_base=false`.
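
For instance, to record without the mobile base and antennas, a sketch reusing the record command above with matched flags on both sides (all flags appear in the options listed above; the combination is illustrative):

```bash
python -m lerobot.record \
    --robot.type=reachy2 \
    --robot.ip_address=192.168.0.200 \
    --robot.use_external_commands=true \
    --robot.with_mobile_base=false \
    --robot.with_antennas=false \
    --teleop.type=reachy2_teleoperator \
    --teleop.ip_address=192.168.0.200 \
    --teleop.with_mobile_base=false \
    --teleop.with_antennas=false \
    --dataset.repo_id=pollen_robotics/record_test \
    --dataset.single_task="Reachy 2 recording test" \
    --dataset.num_episodes=1 \
    --dataset.fps=15
```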

##### Use the relevant cameras

You can do the same for **cameras**. By default, only the **teleoperation cameras** are recorded (both `left_teleop_camera` and `right_teleop_camera`). Enable or disable each camera with:

```
--robot.with_left_teleop_camera=<true|false>
--robot.with_right_teleop_camera=<true|false>
--robot.with_torso_camera=<true|false>
```

## Step 2: Replay

Make sure the robot is configured with the same parts as the dataset:

```bash
python -m lerobot.replay \
    --robot.type=reachy2 \
    --robot.ip_address=192.168.0.200 \
    --robot.use_external_commands=false \
    --robot.with_mobile_base=false \
    --dataset.repo_id=pollen_robotics/record_test \
    --dataset.episode=0 \
    --display_data=true
```

## Step 3: Train

```bash
python -m lerobot.scripts.train \
    --dataset.repo_id=pollen_robotics/record_test \
    --policy.type=act \
    --output_dir=outputs/train/reachy2_test \
    --job_name=reachy2 \
    --policy.device=mps \
    --wandb.enable=true \
    --policy.repo_id=pollen_robotics/record_test_policy
```

## Step 4: Evaluate

```bash
python -m lerobot.record \
    --robot.type=reachy2 \
    --robot.ip_address=192.168.0.200 \
    --display_data=false \
    --dataset.repo_id=pollen_robotics/eval_record_test \
    --dataset.single_task="Evaluate reachy2 policy" \
    --dataset.num_episodes=10 \
    --policy.path=outputs/train/reachy2_test/checkpoints/last/pretrained_model
```
@@ -92,11 +92,11 @@ print(dataset.hf_dataset)
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
# with the latter, like iterating through the dataset.
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access
# frame indices associated to the first episode:
episode_index = 0
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]

# Then we grab all the image frames from the first camera:
camera_key = dataset.meta.camera_keys[0]
@@ -0,0 +1,116 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This script demonstrates how to train an ACT policy on the DROID dataset,
using a dataset processed in streaming mode.

Once you have trained a model with this script, you can try to evaluate it on
examples/2_evaluate_pretrained_policy.py
"""

from pathlib import Path

import torch

from lerobot.configs.types import FeatureType
from lerobot.constants import ACTION
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy


def main():
    # Create a directory to store the training checkpoint.
    output_directory = Path("outputs/train/example_streaming_dataset")
    output_directory.mkdir(parents=True, exist_ok=True)

    # Selects the "best" device available
    device = (
        torch.device("cuda")
        if torch.cuda.is_available()
        else torch.device("mps")
        if torch.backends.mps.is_available()
        else torch.device("cpu")
    )
    print(f"Using device: {device}")

    training_steps = 10
    log_freq = 1

    dataset_id = (
        "aractingi/droid_1.0.1"  # 26M frames! Would require 4TB of disk space if installed locally (:
    )
    dataset_metadata = LeRobotDatasetMetadata(dataset_id)
    features = dataset_to_policy_features(dataset_metadata.features)
    output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
    input_features = {key: ft for key, ft in features.items() if key not in output_features}

    # We can now instantiate our policy with this config and the dataset stats.
    cfg = ACTConfig(input_features=input_features, output_features=output_features)
    policy = ACTPolicy(cfg, dataset_stats=dataset_metadata.stats)
    policy.train()
    policy.to(device)

    # Delta timestamps are used to (1) augment frames used during training and (2) supervise the policy.
    # Here, we use delta timestamps to only provide ground-truth actions for supervision.
    delta_timestamps = {
        ACTION: [t / dataset_metadata.fps for t in range(cfg.n_action_steps)],
    }

    # Instantiating the training dataset in streaming mode avoids loading the data into memory all at
    # once: frames are fetched iteratively, and retrieved frames are shuffled across epochs.
    dataset = StreamingLeRobotDataset(dataset_id, delta_timestamps=delta_timestamps, tolerance_s=1e-3)

    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        batch_size=16,
        pin_memory=device.type != "cpu",
        drop_last=True,
        prefetch_factor=2,  # loads batches with multiprocessing while policy trains
    )

    # Run training loop.
    step = 0
    done = False
    while not done:
        for batch in dataloader:
            batch = {
                k: (v.type(torch.float32) if isinstance(v, torch.Tensor) and v.dtype != torch.bool else v)
                for k, v in batch.items()
            }
            batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}

            loss, _ = policy.forward(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if step % log_freq == 0:
                print(f"step: {step} loss: {loss.item():.3f}")
            step += 1
            if step >= training_steps:
                done = True
                break

    # Save a policy checkpoint.
    policy.save_pretrained(output_directory)


if __name__ == "__main__":
    main()
@@ -1,7 +1,6 @@
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.record import record_loop
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.utils.control_utils import init_keyboard_listener
@@ -12,14 +11,12 @@ NUM_EPISODES = 2
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"

# Create the robot and teleoperator configurations
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
robot = LeKiwiClient(robot_config)

policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
policy = ACTPolicy.from_pretrained("<hf_username>/<policy_repo_id>")

# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
@@ -28,7 +25,7 @@ dataset_features = {**action_features, **obs_features}

# Create the dataset
dataset = LeRobotDataset.create(
    repo_id=HF_DATASET_ID,
    repo_id="<hf_username>/<eval_dataset_repo_id>",
    fps=FPS,
    features=dataset_features,
    robot_type=robot.name,
@@ -46,12 +43,6 @@ listener, events = init_keyboard_listener()
if not robot.is_connected:
    raise ValueError("Robot is not connected!")

preprocessor, postprocessor = make_pre_post_processors(
    policy_cfg=policy,
    pretrained_path=HF_MODEL_ID,
    dataset_stats=dataset.meta.stats,
)

recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
    log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
@@ -62,8 +53,6 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
        events=events,
        fps=FPS,
        policy=policy,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        dataset=dataset,
        control_time_s=EPISODE_TIME_SEC,
        single_task=TASK_DESCRIPTION,
@@ -38,7 +38,7 @@ while True:
    keyboard_keys = keyboard.get_action()
    base_action = robot._from_keyboard_to_base_action(keyboard_keys)

    log_rerun_data(observation=observation, action={**arm_action, **base_action})
    log_rerun_data(observation, {**arm_action, **base_action})

    action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
@@ -1,158 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import (
    observation_to_transition,
    transition_to_robot_action,
)
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
    AddRobotObservationAsComplimentaryData,
    ForwardKinematicsJointsToEE,
    InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun

NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"

# Initialize the robot with degrees
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
    port="/dev/tty.usbmodem58760434471",
    id="my_awesome_follower_arm",
    cameras=camera_config,
    use_degrees=True,
)

# Initialize the robot
robot = SO100Follower(robot_config)

# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
    urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
    target_frame_name="gripper_frame_link",
    joint_names=list(robot.bus.motors.keys()),
)

# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints_processor = RobotProcessorPipeline(
    steps=[
        AddRobotObservationAsComplimentaryData(robot=robot),
        InverseKinematicsEEToJoints(
            kinematics=kinematics_solver,
            motor_names=list(robot.bus.motors.keys()),
            initial_guess_current_joints=True,
        ),
    ],
    to_transition=lambda tr: tr,
    to_output=transition_to_robot_action,
)

# Build pipeline to convert joint observation to ee pose observation
robot_joints_to_ee_pose_processor = RobotProcessorPipeline(
    steps=[
        ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
    ],
    to_transition=observation_to_transition,
    to_output=lambda tr: tr,
)

# Build dataset action and gripper features
action_ee_and_gripper = aggregate_pipeline_dataset_features(
    pipeline=robot_ee_to_joints_processor,
    initial_features={},
    use_videos=True,
    patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"],
)  # Get all ee action features + gripper pos action features

# Build dataset observation features
obs_ee = aggregate_pipeline_dataset_features(
    pipeline=robot_joints_to_ee_pose_processor,
    initial_features=robot.observation_features,
    use_videos=True,
    patterns=["observation.state.ee"],
)  # Get all ee observation features

dataset_features = combine_feature_dicts(obs_ee, action_ee_and_gripper)

print("All dataset features: ", dataset_features)

# Create the dataset
dataset = LeRobotDataset.create(
    repo_id=HF_DATASET_ID,
    fps=FPS,
    features=dataset_features,
    robot_type=robot.name,
    use_videos=True,
    image_writer_threads=4,
)

# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
_init_rerun(session_name="recording_phone")

# Connect the robot and teleoperator
robot.connect()

episode_idx = 0

policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
preprocessor, postprocessor = make_pre_post_processors(
    policy_cfg=policy,
    pretrained_path=HF_MODEL_ID,
    dataset_stats=dataset.meta.stats,
)

for episode_idx in range(NUM_EPISODES):
    log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")

    record_loop(
        robot=robot,
        events=events,
        fps=FPS,
        policy=policy,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        dataset=dataset,
        control_time_s=EPISODE_TIME_SEC,
        single_task=TASK_DESCRIPTION,
        display_data=True,
        robot_action_processor=robot_ee_to_joints_processor,
        robot_observation_processor=robot_joints_to_ee_pose_processor,
    )
    dataset.save_episode()

# Clean up
log_say("Stop recording")
robot.disconnect()
dataset.push_to_hub()
@@ -1,215 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import (
    action_to_transition,
    observation_to_transition,
    transition_to_robot_action,
)
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
    AddRobotObservationAsComplimentaryData,
    EEBoundsAndSafety,
    EEReferenceAndDelta,
    ForwardKinematicsJointsToEE,
    GripperVelocityToJoint,
    InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun

NUM_EPISODES = 10
FPS = 30
EPISODE_TIME_SEC = 60
RESET_TIME_SEC = 30
TASK_DESCRIPTION = "My task description"
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"

# Initialize the robot and teleoperator
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
    port="/dev/tty.usbmodem58760434471",
    id="my_awesome_follower_arm",
    cameras=camera_config,
    use_degrees=True,
)
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS)  # or PhoneOS.ANDROID

# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
phone = Phone(teleop_config)

# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
    urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
    target_frame_name="gripper_frame_link",
    joint_names=list(robot.bus.motors.keys()),
)

# Build pipeline to convert phone action to ee pose action
phone_to_robot_ee_pose_processor = RobotProcessorPipeline(
    steps=[
        MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
        AddRobotObservationAsComplimentaryData(robot=robot),
        EEReferenceAndDelta(
            kinematics=kinematics_solver,
            end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
            motor_names=list(robot.bus.motors.keys()),
        ),
        EEBoundsAndSafety(
            end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
            max_ee_step_m=0.20,
            max_ee_twist_step_rad=0.50,
        ),
    ],
    to_transition=action_to_transition,
    to_output=lambda tr: tr,
)

# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints_processor = RobotProcessorPipeline(
    steps=[
        InverseKinematicsEEToJoints(
            kinematics=kinematics_solver,
            motor_names=list(robot.bus.motors.keys()),
            initial_guess_current_joints=True,
        ),
        GripperVelocityToJoint(
            motor_names=list(robot.bus.motors.keys()),
            speed_factor=20.0,
        ),
    ],
    to_transition=lambda tr: tr,
    to_output=transition_to_robot_action,
)

# Build pipeline to convert joint observation to ee pose observation
robot_joints_to_ee_pose = RobotProcessorPipeline(
    steps=[
        ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
    ],
    to_transition=observation_to_transition,
    to_output=lambda tr: tr,
)

# Build dataset ee action features
action_ee = aggregate_pipeline_dataset_features(
    pipeline=phone_to_robot_ee_pose_processor,
    initial_features=phone.action_features,
    use_videos=True,
    patterns=["action.ee"],
)

# Get gripper pos action features
gripper = aggregate_pipeline_dataset_features(
    pipeline=robot_ee_to_joints_processor,
    initial_features={},
    use_videos=True,
    patterns=["action.gripper.pos", "observation.state.gripper.pos"],
)

# Build dataset ee observation features
observation_ee = aggregate_pipeline_dataset_features(
    pipeline=robot_joints_to_ee_pose,
    initial_features=robot.observation_features,
    use_videos=True,
    patterns=["observation.state.ee"],
)

dataset_features = combine_feature_dicts(action_ee, gripper, observation_ee)

print("All dataset features: ", dataset_features)

# Create the dataset
dataset = LeRobotDataset.create(
    repo_id=HF_REPO_ID,
    fps=FPS,
    features=dataset_features,
    robot_type=robot.name,
    use_videos=True,
    image_writer_threads=4,
)

# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
_init_rerun(session_name="recording_phone")

# Connect the robot and teleoperator
robot.connect()
phone.connect()

episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
    log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")

    record_loop(
        robot=robot,
        events=events,
        fps=FPS,
        teleop=phone,
        dataset=dataset,
        control_time_s=EPISODE_TIME_SEC,
        single_task=TASK_DESCRIPTION,
        display_data=True,
        teleop_action_processor=phone_to_robot_ee_pose_processor,
        robot_action_processor=robot_ee_to_joints_processor,
        robot_observation_processor=robot_joints_to_ee_pose,
    )

    # Reset the environment if not stopping or re-recording
    if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
        log_say("Reset the environment")
        record_loop(
            robot=robot,
            events=events,
            fps=FPS,
            teleop=phone,
            control_time_s=RESET_TIME_SEC,
            single_task=TASK_DESCRIPTION,
            display_data=True,
            teleop_action_processor=phone_to_robot_ee_pose_processor,
            robot_action_processor=robot_ee_to_joints_processor,
            robot_observation_processor=robot_joints_to_ee_pose,
        )

    if events["rerecord_episode"]:
        log_say("Re-recording episode")
        events["rerecord_episode"] = False
        events["exit_early"] = False
        dataset.clear_episode_buffer()
        continue

    dataset.save_episode()
    episode_idx += 1

# Clean up
log_say("Stop recording")
robot.disconnect()
phone.disconnect()
dataset.push_to_hub()
@@ -1,81 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import time

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import action_to_transition, transition_to_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
    AddRobotObservationAsComplimentaryData,
    InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say

EPISODE_IDX = 0
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"

robot_config = SO100FollowerConfig(
    port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
)
robot = SO100Follower(robot_config)
robot.connect()

dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
actions = dataset.hf_dataset.select_columns("action")

# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
    urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
    target_frame_name="gripper_frame_link",
    joint_names=list(robot.bus.motors.keys()),
)

# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints_processor = RobotProcessorPipeline(
    steps=[
        AddRobotObservationAsComplimentaryData(robot=robot),
        InverseKinematicsEEToJoints(
            kinematics=kinematics_solver,
            motor_names=list(robot.bus.motors.keys()),
            initial_guess_current_joints=False,  # Because replay is open loop
        ),
    ],
    to_transition=action_to_transition,
    to_output=transition_to_robot_action,
)

robot_ee_to_joints_processor.reset()

log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(dataset.num_frames):
    t0 = time.perf_counter()

    ee_action = {
        name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
    }

    joint_action = robot_ee_to_joints_processor(ee_action)
    action_sent = robot.send_action(joint_action)

    busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))

robot.disconnect()
@@ -1,93 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import action_to_transition, transition_to_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
    AddRobotObservationAsComplimentaryData,
    EEBoundsAndSafety,
    EEReferenceAndDelta,
    GripperVelocityToJoint,
    InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone

# Initialize the robot and teleoperator
robot_config = SO100FollowerConfig(
    port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
)
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS)  # or PhoneOS.ANDROID

# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
teleop_device = Phone(teleop_config)

# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
    urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
    target_frame_name="gripper_frame_link",
    joint_names=list(robot.bus.motors.keys()),
)

# Build pipeline to convert phone action to ee pose action to joint action
phone_to_robot_joints_processor = RobotProcessorPipeline(
    steps=[
        MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
        AddRobotObservationAsComplimentaryData(robot=robot),
        EEReferenceAndDelta(
            kinematics=kinematics_solver,
            end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
            motor_names=list(robot.bus.motors.keys()),
        ),
        EEBoundsAndSafety(
            end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
            max_ee_step_m=0.10,
            max_ee_twist_step_rad=0.50,
        ),
        InverseKinematicsEEToJoints(
            kinematics=kinematics_solver,
            motor_names=list(robot.bus.motors.keys()),
        ),
        GripperVelocityToJoint(
            motor_names=list(robot.bus.motors.keys()),
            speed_factor=20.0,
        ),
    ],
    to_transition=action_to_transition,
    to_output=transition_to_robot_action,
)

robot.connect()
teleop_device.connect()

print("Starting teleop loop. Move your phone to teleoperate the robot.")
while True:
    # Get teleop observation
    phone_obs = teleop_device.get_action()

    # Phone -> EE pose -> Joints transition
    joint_action = phone_to_robot_joints_processor(phone_obs)

    if joint_action:
        robot.send_action(joint_action)

    time.sleep(0.01)
@@ -0,0 +1,85 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
from pathlib import Path


def find_missing_workers(completions_dir, world_size):
    """Find workers that have not completed and return their indices."""
    full = list(range(world_size))

    completed = []
    for path in completions_dir.glob("*"):
        if path.name in [".", ".."]:
            continue
        index = path.name.lstrip("0")
        index = 0 if index == "" else int(index)
        completed.append(index)

    missing_workers = set(full) - set(completed)
    return missing_workers


def find_output_files(slurm_dir, worker_indices):
    """Find output files associated with worker indices, and return tuples
    of (worker index, output file path).
    """
    out_files = []
    for path in slurm_dir.glob("*.out"):
        _, worker_id = path.name.replace(".out", "").split("_")
        worker_id = int(worker_id)
        if worker_id in worker_indices:
            out_files.append((worker_id, path))
    return out_files


def display_error_files(logs_dir, job_name):
    executor_path = Path(logs_dir) / job_name / "executor.json"
    completions_dir = Path(logs_dir) / job_name / "completions"

    with open(executor_path) as f:
        executor = json.load(f)

    missing_workers = find_missing_workers(completions_dir, executor["world_size"])

    for missing in sorted(missing_workers)[::-1]:
        print(missing)


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--logs-dir",
        type=str,
        help="Path to logs directory for `datatrove`.",
    )
    parser.add_argument(
        "--job-name",
        type=str,
        default="port_droid",
        help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
    )

    args = parser.parse_args()

    display_error_files(**vars(args))


if __name__ == "__main__":
    main()
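
Run this helper against the same logs directory and job name used for the porting jobs. A usage sketch; the script filename is an assumption here (this appears to be the helper behind the `failed_logs`-style debugging shown in the porting guide above):

```bash
# hypothetical path to the script listed above
python find_failed_workers.py --logs-dir /your/logs --job-name port_droid
```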
@@ -0,0 +1,430 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import time
from pathlib import Path

import numpy as np
import tensorflow_datasets as tfds

from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds

DROID_SHARDS = 2048
DROID_FPS = 15
DROID_ROBOT_TYPE = "Franka"

# Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
DROID_FEATURES = {
    # true on first step of the episode
    "is_first": {
        "dtype": "bool",
        "shape": (1,),
        "names": None,
    },
    # true on last step of the episode
    "is_last": {
        "dtype": "bool",
        "shape": (1,),
        "names": None,
    },
    # true on last step of the episode if it is a terminal step, True for demos
    "is_terminal": {
        "dtype": "bool",
        "shape": (1,),
        "names": None,
    },
    # language_instruction is also stored as "task" to follow LeRobot standard
    "language_instruction": {
        "dtype": "string",
        "shape": (1,),
        "names": None,
    },
    "language_instruction_2": {
        "dtype": "string",
        "shape": (1,),
        "names": None,
    },
    "language_instruction_3": {
        "dtype": "string",
        "shape": (1,),
        "names": None,
    },
    "observation.state.gripper_position": {
        "dtype": "float32",
        "shape": (1,),
        "names": {
            "axes": ["gripper"],
        },
    },
    "observation.state.cartesian_position": {
        "dtype": "float32",
        "shape": (6,),
        "names": {
            "axes": ["x", "y", "z", "roll", "pitch", "yaw"],
        },
    },
    "observation.state.joint_position": {
        "dtype": "float32",
        "shape": (7,),
        "names": {
            "axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
        },
    },
    # Add this new feature to follow LeRobot standard of using joint position + gripper
    "observation.state": {
        "dtype": "float32",
        "shape": (8,),
        "names": {
            "axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"],
        },
    },
    # Initially called wrist_image_left
    "observation.images.wrist_left": {
        "dtype": "video",
        "shape": (180, 320, 3),
        "names": [
            "height",
            "width",
            "channels",
        ],
    },
    # Initially called exterior_image_1_left
    "observation.images.exterior_1_left": {
        "dtype": "video",
        "shape": (180, 320, 3),
        "names": [
            "height",
            "width",
            "channels",
        ],
    },
    # Initially called exterior_image_2_left
    "observation.images.exterior_2_left": {
        "dtype": "video",
        "shape": (180, 320, 3),
        "names": [
            "height",
            "width",
            "channels",
        ],
    },
    "action.gripper_position": {
        "dtype": "float32",
        "shape": (1,),
        "names": {
            "axes": ["gripper"],
        },
    },
    "action.gripper_velocity": {
        "dtype": "float32",
        "shape": (1,),
        "names": {
            "axes": ["gripper"],
        },
    },
    "action.cartesian_position": {
        "dtype": "float32",
        "shape": (6,),
        "names": {
            "axes": ["x", "y", "z", "roll", "pitch", "yaw"],
        },
    },
    "action.cartesian_velocity": {
        "dtype": "float32",
        "shape": (6,),
        "names": {
            "axes": ["x", "y", "z", "roll", "pitch", "yaw"],
        },
    },
    "action.joint_position": {
        "dtype": "float32",
        "shape": (7,),
        "names": {
            "axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
        },
    },
    "action.joint_velocity": {
        "dtype": "float32",
        "shape": (7,),
        "names": {
            "axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
        },
    },
    # This feature was called "action" in RLDS dataset and consists of [6x joint velocities, 1x gripper position]
    "action.original": {
        "dtype": "float32",
        "shape": (7,),
        "names": {
            "axes": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
        },
    },
    # Add this new feature to follow LeRobot standard of using joint position + gripper
    "action": {
        "dtype": "float32",
        "shape": (8,),
        "names": {
            "axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"],
        },
    },
    "discount": {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    },
    "reward": {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    },
    # Meta data that are the same for all frames in the episode
    "task_category": {
        "dtype": "string",
        "shape": (1,),
        "names": None,
    },
    "building": {
        "dtype": "string",
        "shape": (1,),
        "names": None,
    },
    "collector_id": {
        "dtype": "string",
        "shape": (1,),
        "names": None,
    },
    "date": {
        "dtype": "string",
        "shape": (1,),
        "names": None,
    },
    "camera_extrinsics.wrist_left": {
        "dtype": "float32",
        "shape": (6,),
        "names": {
            "axes": ["x", "y", "z", "roll", "pitch", "yaw"],
        },
    },
    "camera_extrinsics.exterior_1_left": {
        "dtype": "float32",
        "shape": (6,),
        "names": {
            "axes": ["x", "y", "z", "roll", "pitch", "yaw"],
        },
    },
    "camera_extrinsics.exterior_2_left": {
        "dtype": "float32",
        "shape": (6,),
        "names": {
            "axes": ["x", "y", "z", "roll", "pitch", "yaw"],
        },
    },
    "is_episode_successful": {
        "dtype": "bool",
        "shape": (1,),
        "names": None,
    },
}


def is_episode_successful(tf_episode_metadata):
    # Adapted from: https://github.com/droid-dataset/droid_policy_learning/blob/dd1020eb20d981f90b5ff07dc80d80d5c0cb108b/robomimic/utils/rlds_utils.py#L8
    return "/success/" in tf_episode_metadata["file_path"].numpy().decode()


def generate_lerobot_frames(tf_episode):
    m = tf_episode["episode_metadata"]
    frame_meta = {
        "task_category": m["building"].numpy().decode(),
        "building": m["building"].numpy().decode(),
        "collector_id": m["collector_id"].numpy().decode(),
        "date": m["date"].numpy().decode(),
        "camera_extrinsics.wrist_left": m["extrinsics_wrist_cam"].numpy(),
        "camera_extrinsics.exterior_1_left": m["extrinsics_exterior_cam_1"].numpy(),
        "camera_extrinsics.exterior_2_left": m["extrinsics_exterior_cam_2"].numpy(),
        "is_episode_successful": np.array([is_episode_successful(m)]),
    }
    for f in tf_episode["steps"]:
        # Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
        frame = {
            "is_first": np.array([f["is_first"].numpy()]),
            "is_last": np.array([f["is_last"].numpy()]),
            "is_terminal": np.array([f["is_terminal"].numpy()]),
            "language_instruction": f["language_instruction"].numpy().decode(),
            "language_instruction_2": f["language_instruction_2"].numpy().decode(),
            "language_instruction_3": f["language_instruction_3"].numpy().decode(),
            "observation.state.gripper_position": f["observation"]["gripper_position"].numpy(),
            "observation.state.cartesian_position": f["observation"]["cartesian_position"].numpy(),
            "observation.state.joint_position": f["observation"]["joint_position"].numpy(),
            "observation.images.wrist_left": f["observation"]["wrist_image_left"].numpy(),
            "observation.images.exterior_1_left": f["observation"]["exterior_image_1_left"].numpy(),
            "observation.images.exterior_2_left": f["observation"]["exterior_image_2_left"].numpy(),
            "action.gripper_position": f["action_dict"]["gripper_position"].numpy(),
            "action.gripper_velocity": f["action_dict"]["gripper_velocity"].numpy(),
            "action.cartesian_position": f["action_dict"]["cartesian_position"].numpy(),
            "action.cartesian_velocity": f["action_dict"]["cartesian_velocity"].numpy(),
            "action.joint_position": f["action_dict"]["joint_position"].numpy(),
            "action.joint_velocity": f["action_dict"]["joint_velocity"].numpy(),
            "discount": np.array([f["discount"].numpy()]),
            "reward": np.array([f["reward"].numpy()]),
            "action.original": f["action"].numpy(),
        }

        # language_instruction is also stored as "task" to follow LeRobot standard
        frame["task"] = frame["language_instruction"]

        # Add this new feature to follow LeRobot standard of using joint position + gripper
        frame["observation.state"] = np.concatenate(
            [frame["observation.state.joint_position"], frame["observation.state.gripper_position"]]
        )
        frame["action"] = np.concatenate([frame["action.joint_position"], frame["action.gripper_position"]])

        # Meta data that are the same for all frames in the episode
        frame.update(frame_meta)

        # Cast fp64 to fp32
        for key in frame:
            if isinstance(frame[key], np.ndarray) and frame[key].dtype == np.float64:
                frame[key] = frame[key].astype(np.float32)

        yield frame


def port_droid(
    raw_dir: Path,
    repo_id: str,
    push_to_hub: bool = False,
    num_shards: int | None = None,
    shard_index: int | None = None,
):
    dataset_name = raw_dir.parent.name
    version = raw_dir.name
    data_dir = raw_dir.parent.parent

    builder = tfds.builder(f"{dataset_name}/{version}", data_dir=data_dir, version="")

    if num_shards is not None:
        tfds_num_shards = builder.info.splits["train"].num_shards
        if tfds_num_shards != DROID_SHARDS:
            raise ValueError(
                f"Number of shards of Droid dataset is expected to be {DROID_SHARDS} but is {tfds_num_shards}."
            )
        if num_shards != tfds_num_shards:
            raise ValueError(
                f"We only shard over the fixed number of shards provided by tensorflow dataset ({tfds_num_shards}), but {num_shards} shards provided instead."
            )
        if shard_index >= tfds_num_shards:
            raise ValueError(
                f"Shard index is greater than the num of shards ({shard_index} >= {num_shards})."
            )

        raw_dataset = builder.as_dataset(split=f"train[{shard_index}shard]")
    else:
        raw_dataset = builder.as_dataset(split="train")

    lerobot_dataset = LeRobotDataset.create(
        repo_id=repo_id,
        robot_type=DROID_ROBOT_TYPE,
        fps=DROID_FPS,
        features=DROID_FEATURES,
    )

    start_time = time.time()
    num_episodes = raw_dataset.cardinality().numpy().item()
    logging.info(f"Number of episodes {num_episodes}")

    for episode_index, episode in enumerate(raw_dataset):
        elapsed_time = time.time() - start_time
        d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)

        logging.info(
            f"{episode_index} / {num_episodes} episodes processed (after {d} days, {h} hours, {m} minutes, {s:.3f} seconds)"
|
||||
)
|
||||
|
||||
for frame in generate_lerobot_frames(episode):
|
||||
lerobot_dataset.add_frame(frame)
|
||||
|
||||
lerobot_dataset.save_episode()
|
||||
logging.info("Save_episode")
|
||||
|
||||
if push_to_hub:
|
||||
lerobot_dataset.push_to_hub(
|
||||
# Add openx tag, since it belongs to the openx collection of datasets
|
||||
tags=["openx"],
|
||||
private=False,
|
||||
)
|
||||
|
||||
|
||||
def validate_dataset(repo_id):
|
||||
"""Sanity check that ensure meta data can be loaded and all files are present."""
|
||||
meta = LeRobotDatasetMetadata(repo_id)
|
||||
|
||||
if meta.total_episodes == 0:
|
||||
raise ValueError("Number of episodes is 0.")
|
||||
|
||||
for ep_idx in range(meta.total_episodes):
|
||||
data_path = meta.root / meta.get_data_file_path(ep_idx)
|
||||
|
||||
if not data_path.exists():
|
||||
raise ValueError(f"Parquet file is missing in: {data_path}")
|
||||
|
||||
for vid_key in meta.video_keys:
|
||||
vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key)
|
||||
if not vid_path.exists():
|
||||
raise ValueError(f"Video file is missing in: {vid_path}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--raw-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Upload to hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-shards",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of shards. Can be either None to load the full dataset, or 2048 to load one of the 2048 tensorflow dataset files.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shard-index",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Index of the shard. Can be either None to load the full dataset, or in [0,2047] to load one of the 2048 tensorflow dataset files.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
port_droid(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
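For orientation, porting a single TFDS shard end to end amounts to a call like the following. This is a minimal sketch, not part of the diff; the local `/data/droid/1.0.0` layout and the repo id are placeholders implied by the `raw_dir` parsing in `port_droid` and the shard naming used by the slurm scripts below.

```python
from pathlib import Path

# Hypothetical layout: /data/droid/1.0.0 holds the TFDS files, so that
# raw_dir.name is the version and raw_dir.parent.name is the dataset name.
port_droid(
    raw_dir=Path("/data/droid/1.0.0"),
    repo_id="my-user/droid_world_2048_rank_0",  # placeholder repo id
    push_to_hub=False,
    num_shards=2048,  # must equal the fixed TFDS shard count (DROID_SHARDS)
    shard_index=0,
)
validate_dataset("my-user/droid_world_2048_rank_0")
```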
@@ -0,0 +1,148 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
from pathlib import Path

from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_datasets.droid_rlds.port_droid import DROID_SHARDS

from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.utils.utils import init_logging


class AggregateDatasets(PipelineStep):
    def __init__(
        self,
        repo_ids: list[str],
        aggregated_repo_id: str,
    ):
        super().__init__()
        self.repo_ids = repo_ids
        self.aggr_repo_id = aggregated_repo_id

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        init_logging()

        # Since aggregate_datasets already handles parallel processing internally,
        # we only need one worker to run the entire aggregation
        if rank == 0:
            logging.info(f"Starting aggregation of {len(self.repo_ids)} datasets into {self.aggr_repo_id}")
            aggregate_datasets(self.repo_ids, self.aggr_repo_id)
            logging.info("Aggregation complete!")
        else:
            logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation")


def make_aggregate_executor(
    repo_ids, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
    kwargs = {
        "pipeline": [
            AggregateDatasets(repo_ids, repo_id),
        ],
        "logging_dir": str(logs_dir / job_name),
    }

    if slurm:
        # For aggregation, we only need 1 task since aggregate_datasets handles everything
        kwargs.update(
            {
                "job_name": job_name,
                "tasks": 1,  # Only need 1 task for aggregation
                "workers": 1,  # Only need 1 worker
                "time": "08:00:00",
                "partition": partition,
                "cpus_per_task": cpus_per_task,
                "sbatch_args": {"mem-per-cpu": mem_per_cpu},
            }
        )
        executor = SlurmPipelineExecutor(**kwargs)
    else:
        kwargs.update(
            {
                "tasks": 1,
                "workers": 1,
            }
        )
        executor = LocalPipelineExecutor(**kwargs)

    return executor


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
    )
    parser.add_argument(
        "--logs-dir",
        type=Path,
        help="Path to logs directory for `datatrove`.",
    )
    parser.add_argument(
        "--job-name",
        type=str,
        default="aggr_droid",
        help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
    )
    parser.add_argument(
        "--slurm",
        type=int,
        default=1,
        help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=1,  # Changed default to 1 since aggregation doesn't need multiple workers
        help="Number of slurm workers. For aggregation, this should be 1.",
    )
    parser.add_argument(
        "--partition",
        type=str,
        help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
    )
    parser.add_argument(
        "--cpus-per-task",
        type=int,
        default=8,
        help="Number of cpus that each slurm worker will use.",
    )
    parser.add_argument(
        "--mem-per-cpu",
        type=str,
        default="1950M",
        help="Memory per cpu that each worker will use.",
    )

    args = parser.parse_args()
    kwargs = vars(args)
    kwargs["slurm"] = kwargs.pop("slurm") == 1

    repo_ids = [f"{args.repo_id}_world_{DROID_SHARDS}_rank_{rank}" for rank in range(DROID_SHARDS)]
    aggregate_executor = make_aggregate_executor(repo_ids, **kwargs)
    aggregate_executor.run()


if __name__ == "__main__":
    main()
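As a usage sketch, the same executor can be driven without slurm, equivalent to passing `--slurm 0`. The two shard repo ids below are illustrative only; real ids follow the `_world_{DROID_SHARDS}_rank_{rank}` naming that `main()` derives.

```python
from pathlib import Path

executor = make_aggregate_executor(
    repo_ids=[
        "my-user/droid_world_2048_rank_0",  # placeholder shard repos
        "my-user/droid_world_2048_rank_1",
    ],
    repo_id="my-user/droid",
    job_name="aggr_droid",
    logs_dir=Path("logs"),
    workers=1,
    partition=None,  # unused when slurm=False
    cpus_per_task=8,
    mem_per_cpu="1950M",
    slurm=False,  # run locally, sequentially
)
executor.run()
```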
@@ -0,0 +1,162 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from pathlib import Path

from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_datasets.droid_rlds.port_droid import DROID_SHARDS


class PortDroidShards(PipelineStep):
    def __init__(
        self,
        raw_dir: Path | str,
        repo_id: str | None = None,
    ):
        super().__init__()
        self.raw_dir = Path(raw_dir)
        self.repo_id = repo_id

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        from datasets.utils.tqdm import disable_progress_bars
        from port_datasets.droid_rlds.port_droid import port_droid, validate_dataset

        from lerobot.utils.utils import init_logging

        init_logging()
        disable_progress_bars()

        shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"

        try:
            validate_dataset(shard_repo_id)
            return
        except Exception:
            pass  # nosec B110 - Dataset doesn't exist yet, continue with porting

        port_droid(
            self.raw_dir,
            shard_repo_id,
            push_to_hub=False,
            num_shards=world_size,
            shard_index=rank,
        )

        validate_dataset(shard_repo_id)


def make_port_executor(
    raw_dir, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
    kwargs = {
        "pipeline": [
            PortDroidShards(raw_dir, repo_id),
        ],
        "logging_dir": str(logs_dir / job_name),
    }

    if slurm:
        kwargs.update(
            {
                "job_name": job_name,
                "tasks": DROID_SHARDS,
                "workers": workers,
                "time": "08:00:00",
                "partition": partition,
                "cpus_per_task": cpus_per_task,
                "sbatch_args": {"mem-per-cpu": mem_per_cpu},
            }
        )
        executor = SlurmPipelineExecutor(**kwargs)
    else:
        kwargs.update(
            {
                "tasks": 1,
                "workers": 1,
            }
        )
        executor = LocalPipelineExecutor(**kwargs)

    return executor


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--raw-dir",
        type=Path,
        required=True,
        help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version`).",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
    )
    parser.add_argument(
        "--logs-dir",
        type=Path,
        help="Path to logs directory for `datatrove`.",
    )
    parser.add_argument(
        "--job-name",
        type=str,
        default="port_droid",
        help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
    )
    parser.add_argument(
        "--slurm",
        type=int,
        default=1,
        help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=2048,
        help="Number of slurm workers. It should be less than the maximum number of shards.",
    )
    parser.add_argument(
        "--partition",
        type=str,
        help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
    )
    parser.add_argument(
        "--cpus-per-task",
        type=int,
        default=8,
        help="Number of cpus that each slurm worker will use.",
    )
    parser.add_argument(
        "--mem-per-cpu",
        type=str,
        default="1950M",
        help="Memory per cpu that each worker will use.",
    )

    args = parser.parse_args()
    kwargs = vars(args)
    kwargs["slurm"] = kwargs.pop("slurm") == 1
    port_executor = make_port_executor(**kwargs)
    port_executor.run()


if __name__ == "__main__":
    main()
@@ -0,0 +1,281 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
from pathlib import Path

from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from huggingface_hub import HfApi
from huggingface_hub.constants import REPOCARD_NAME
from port_datasets.droid_rlds.port_droid import DROID_SHARDS

from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.utils import create_lerobot_dataset_card
from lerobot.utils.utils import init_logging


class UploadDataset(PipelineStep):
    def __init__(
        self,
        repo_id: str,
        branch: str | None = None,
        revision: str | None = None,
        tags: list | None = None,
        license: str | None = "apache-2.0",
        private: bool = False,
        distant_repo_id: str | None = None,
        **card_kwargs,
    ):
        super().__init__()
        self.repo_id = repo_id
        self.distant_repo_id = self.repo_id if distant_repo_id is None else distant_repo_id
        self.branch = branch
        self.tags = tags
        self.license = license
        self.private = private
        self.card_kwargs = card_kwargs
        self.revision = revision if revision else CODEBASE_VERSION

        if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") != "1":
            logging.warning(
                'HF_HUB_ENABLE_HF_TRANSFER is not set to "1". Install hf_transfer and set the env '
                "variable for faster uploads:\npip install hf-transfer\nexport HF_HUB_ENABLE_HF_TRANSFER=1"
            )

        self.create_repo()

    def create_repo(self):
        logging.info(f"Loading meta data from {self.repo_id}...")
        meta = LeRobotDatasetMetadata(self.repo_id)

        logging.info(f"Creating repo {self.distant_repo_id}...")
        hub_api = HfApi()
        hub_api.create_repo(
            repo_id=self.distant_repo_id,
            private=self.private,
            repo_type="dataset",
            exist_ok=True,
        )
        if self.branch:
            hub_api.create_branch(
                repo_id=self.distant_repo_id,
                branch=self.branch,
                revision=self.revision,
                repo_type="dataset",
                exist_ok=True,
            )

        if not hub_api.file_exists(
            self.distant_repo_id, REPOCARD_NAME, repo_type="dataset", revision=self.branch
        ):
            card = create_lerobot_dataset_card(
                tags=self.tags, dataset_info=meta.info, license=self.license, **self.card_kwargs
            )
            card.push_to_hub(repo_id=self.distant_repo_id, repo_type="dataset", revision=self.branch)

        hub_api.create_tag(self.distant_repo_id, tag=CODEBASE_VERSION, repo_type="dataset")

        def list_files_recursively(directory):
            base_path = Path(directory)
            return [str(file.relative_to(base_path)) for file in base_path.rglob("*") if file.is_file()]

        logging.info(f"Listing all local files from {self.repo_id}...")
        self.file_paths = list_files_recursively(meta.root)
        self.file_paths = sorted(self.file_paths)

    def create_chunks(self, lst, n):
        from itertools import islice

        it = iter(lst)
        return [list(islice(it, size)) for size in [len(lst) // n + (i < len(lst) % n) for i in range(n)]]

    def create_commits(self, additions):
        import logging
        import math
        import random
        import time

        from huggingface_hub import create_commit
        from huggingface_hub.utils import HfHubHTTPError

        FILES_BETWEEN_COMMITS = 10  # noqa: N806
        BASE_DELAY = 0.1  # noqa: N806
        MAX_RETRIES = 12  # noqa: N806

        # Split the files into smaller chunks for faster commit
        # and avoiding "A commit has happened since" error
        num_chunks = math.ceil(len(additions) / FILES_BETWEEN_COMMITS)
        chunks = self.create_chunks(additions, num_chunks)

        for chunk in chunks:
            retries = 0
            while True:
                try:
                    create_commit(
                        self.distant_repo_id,
                        repo_type="dataset",
                        operations=chunk,
                        commit_message=f"DataTrove upload ({len(chunk)} files)",
                        revision=self.branch,
                    )
                    # TODO: every 100 chunks super_squash_commits()
                    logging.info("create_commit completed!")
                    break
                except HfHubHTTPError as e:
                    if "A commit has happened since" in e.server_message:
                        if retries >= MAX_RETRIES:
                            logging.error(f"Failed to create commit after {MAX_RETRIES=}. Giving up.")
                            raise e
                        logging.info("Commit creation race condition issue. Waiting...")
                        time.sleep(BASE_DELAY * 2**retries + random.uniform(0, 2))
                        retries += 1
                    else:
                        raise e

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        import logging

        from datasets.utils.tqdm import disable_progress_bars
        from huggingface_hub import CommitOperationAdd, preupload_lfs_files

        from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
        from lerobot.utils.utils import init_logging

        init_logging()
        disable_progress_bars()

        chunks = self.create_chunks(self.file_paths, world_size)
        file_paths = chunks[rank]

        if len(file_paths) == 0:
            raise ValueError(file_paths)

        logging.info("Pre-uploading LFS files...")
        for i, path in enumerate(file_paths):
            logging.info(f"{i}: {path}")

        meta = LeRobotDatasetMetadata(self.repo_id)
        additions = [
            CommitOperationAdd(path_in_repo=path, path_or_fileobj=meta.root / path) for path in file_paths
        ]
        preupload_lfs_files(
            repo_id=self.distant_repo_id, repo_type="dataset", additions=additions, revision=self.branch
        )

        logging.info("Creating commits...")
        self.create_commits(additions)
        logging.info("Done!")


def make_upload_executor(
    repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
    kwargs = {
        "pipeline": [
            UploadDataset(repo_id),
        ],
        "logging_dir": str(logs_dir / job_name),
    }

    if slurm:
        kwargs.update(
            {
                "job_name": job_name,
                "tasks": DROID_SHARDS,
                "workers": workers,
                "time": "08:00:00",
                "partition": partition,
                "cpus_per_task": cpus_per_task,
                "sbatch_args": {"mem-per-cpu": mem_per_cpu},
            }
        )
        executor = SlurmPipelineExecutor(**kwargs)
    else:
        kwargs.update(
            {
                "tasks": DROID_SHARDS,
                "workers": 1,
            }
        )
        executor = LocalPipelineExecutor(**kwargs)

    return executor


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
    )
    parser.add_argument(
        "--logs-dir",
        type=Path,
        help="Path to logs directory for `datatrove`.",
    )
    parser.add_argument(
        "--job-name",
        type=str,
        default="upload_droid",
        help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
    )
    parser.add_argument(
        "--slurm",
        type=int,
        default=1,
        help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=50,
        help="Number of slurm workers. It should be less than the maximum number of shards.",
    )
    parser.add_argument(
        "--partition",
        type=str,
        help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
    )
    parser.add_argument(
        "--cpus-per-task",
        type=int,
        default=8,
        help="Number of cpus that each slurm worker will use.",
    )
    parser.add_argument(
        "--mem-per-cpu",
        type=str,
        default="1950M",
        help="Memory per cpu that each worker will use.",
    )

    init_logging()

    args = parser.parse_args()
    kwargs = vars(args)
    kwargs["slurm"] = kwargs.pop("slurm") == 1
    upload_executor = make_upload_executor(**kwargs)
    upload_executor.run()


if __name__ == "__main__":
    main()
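The `create_chunks` helper above splits a list into `n` nearly equal chunks, handing the `len(lst) % n` leftover items to the first chunks one each. A standalone sketch of the same `islice` pattern, for readers who want to check the arithmetic:

```python
from itertools import islice


def create_chunks(lst, n):
    # The first len(lst) % n chunks receive one extra item,
    # so chunk sizes differ by at most one.
    it = iter(lst)
    return [list(islice(it, size)) for size in [len(lst) // n + (i < len(lst) % n) for i in range(n)]]


assert create_chunks(list(range(10)), 3) == [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```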
+4
-6
@@ -73,7 +73,6 @@ dependencies = [
    "pynput>=1.7.7",
    "pyserial>=3.5",
    "wandb>=0.20.0",
    "scipy>=1.15.2",

    "torch>=2.2.1,<2.8.0", # TODO: Bump dependency
    "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
@@ -85,7 +84,6 @@ dependencies = [

    # Support dependencies
    "deepdiff>=7.0.1,<9.0.0",
    "flask>=3.0.3,<4.0.0",
    "imageio[ffmpeg]>=2.34.0,<3.0.0",
    "termcolor>=2.4.0,<4.0.0",
]
@@ -96,7 +94,7 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1"]
placo-dep = ["placo>=0.9.6"]
transformers-dep = ["transformers<=4.52.0"]
transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bump dependency
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]

# Motors
@@ -107,12 +105,12 @@ dynamixel = ["dynamixel-sdk>=3.7.31"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1"]
reachy2 = ["reachy2_sdk>=1.0.14"]
kinematics = ["lerobot[placo-dep]"]
intelrealsense = [
    "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
    "pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
]
phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"]
# stretch = [
#     "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
#     "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
@@ -143,6 +141,7 @@ all = [
    "lerobot[gamepad]",
    "lerobot[hopejr]",
    "lerobot[lekiwi]",
    "lerobot[reachy2]",
    "lerobot[kinematics]",
    "lerobot[intelrealsense]",
    "lerobot[pi0]",
@@ -154,8 +153,7 @@ all = [
    "lerobot[video_benchmark]",
    "lerobot[aloha]",
    "lerobot[pusht]",
    "lerobot[xarm]",
    "lerobot[phone]",
    "lerobot[xarm]"
]

[project.scripts]

+3
-5
@@ -1,6 +1,4 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .config_phone import PhoneConfig
from .phone import Phone
from .configuration_reachy2_camera import Reachy2CameraConfig
from .reachy2_camera import Reachy2Camera
@@ -0,0 +1,78 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

from ..configs import CameraConfig, ColorMode


@CameraConfig.register_subclass("reachy2_camera")
@dataclass
class Reachy2CameraConfig(CameraConfig):
    """Configuration class for Reachy 2 camera devices.

    This class provides configuration options for Reachy 2 cameras,
    supporting both the teleop and depth cameras. It includes settings
    for resolution, frame rate, color mode, and camera selection.

    Example configurations:
    ```python
    # Basic configurations
    Reachy2CameraConfig(
        name="teleop",
        image_type="left",
        ip_address="192.168.0.200",  # IP address of the robot
        fps=15,
        width=640,
        height=480,
        color_mode=ColorMode.RGB,
    )  # Left teleop camera, 640x480 @ 15FPS
    ```

    Attributes:
        name: Name of the camera device. Can be "teleop" or "depth".
        image_type: Type of image stream. For "teleop" camera, can be "left" or "right".
            For "depth" camera, can be "rgb" or "depth". (depth is not supported yet)
        fps: Requested frames per second for the color stream.
        width: Requested frame width in pixels for the color stream.
        height: Requested frame height in pixels for the color stream.
        color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
        ip_address: IP address of the robot. Defaults to "localhost".
        port: Port number for the camera server. Defaults to 50065.

    Note:
        - Only 3-channel color output (RGB/BGR) is currently supported.
    """

    name: str
    image_type: str
    color_mode: ColorMode = ColorMode.RGB
    ip_address: str | None = "localhost"
    port: int = 50065
    # use_depth: bool = False

    def __post_init__(self):
        if self.name not in ["teleop", "depth"]:
            raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
        if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
            self.name == "depth" and self.image_type not in ["rgb", "depth"]
        ):
            raise ValueError(
                f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
            )

        if self.color_mode not in ["rgb", "bgr"]:
            raise ValueError(
                f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
            )
@@ -0,0 +1,288 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager.
"""

import logging
import os
import platform
import time
from threading import Event, Lock, Thread
from typing import Any

# Fix MSMF hardware transform compatibility for Windows before importing cv2
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
    os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2
import numpy as np
from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager

from lerobot.errors import DeviceNotConnectedError

from ..camera import Camera
from .configuration_reachy2_camera import ColorMode, Reachy2CameraConfig

logger = logging.getLogger(__name__)


class Reachy2Camera(Camera):
    """
    Manages Reachy 2 camera using Reachy 2 CameraManager.

    This class provides a high-level interface to connect to, configure, and read
    frames from Reachy 2 cameras. It supports both synchronous and asynchronous
    frame reading.

    A Reachy2Camera instance requires a camera name (e.g., "teleop") and an image
    type (e.g., "left") to be specified in the configuration.

    The camera's default settings (FPS, resolution, color mode) are used unless
    overridden in the configuration.
    """

    def __init__(self, config: Reachy2CameraConfig):
        """
        Initializes the Reachy2Camera instance.

        Args:
            config: The configuration settings for the camera.
        """
        super().__init__(config)

        self.config = config

        self.fps = config.fps
        self.color_mode = config.color_mode

        self.cam_manager: CameraManager | None = None

        self.thread: Thread | None = None
        self.stop_event: Event | None = None
        self.frame_lock: Lock = Lock()
        self.latest_frame: np.ndarray | None = None
        self.new_frame_event: Event = Event()

    def __str__(self) -> str:
        return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})"

    @property
    def is_connected(self) -> bool:
        """Checks if the camera is currently connected and opened."""
        if self.config.name == "teleop":
            return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
        elif self.config.name == "depth":
            return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
        else:
            raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")

    def connect(self, warmup: bool = True):
        """
        Connects to the Reachy2 CameraManager as specified in the configuration.
        """
        self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port)
        self.cam_manager.initialize_cameras()

        logger.info(f"{self} connected.")

    @staticmethod
    def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[str, Any]]:
        """
        Detects available Reachy 2 cameras.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries,
            where each dictionary contains 'name', 'stereo',
            and the default profile properties (width, height, fps).
        """
        initialized_cameras = []
        camera_manager = CameraManager(host=ip_address, port=port)

        for camera in [camera_manager.teleop, camera_manager.depth]:
            if camera is None:
                continue

            height, width, _, _, _, _, _ = camera.get_parameters()

            camera_info = {
                "name": camera._cam_info.name,
                "stereo": camera._cam_info.stereo,
                "default_profile": {
                    "width": width,
                    "height": height,
                    "fps": 30,
                },
            }
            initialized_cameras.append(camera_info)

        camera_manager.disconnect()
        return initialized_cameras

    def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
        """
        Reads a single frame synchronously from the camera.

        This is a blocking call.

        Args:
            color_mode (Optional[ColorMode]): If specified, overrides the default
                color mode (`self.color_mode`) for this read operation (e.g.,
                request RGB even if default is BGR).

        Returns:
            np.ndarray: The captured frame as a NumPy array in the format
            (height, width, channels), using the specified or default
            color mode and applying any configured rotation.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        start_time = time.perf_counter()

        frame = None

        if self.cam_manager is None:
            raise DeviceNotConnectedError(f"{self} is not connected.")
        else:
            if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
                if self.config.image_type == "left":
                    frame = self.cam_manager.teleop.get_frame(CameraView.LEFT, size=(640, 480))[0]
                elif self.config.image_type == "right":
                    frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT, size=(640, 480))[0]
            elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"):
                if self.config.image_type == "depth":
                    frame = self.cam_manager.depth.get_depth_frame()[0]
                elif self.config.image_type == "rgb":
                    frame = self.cam_manager.depth.get_frame(size=(640, 480))[0]

        if frame is None:
            return np.empty((0, 0, 3), dtype=np.uint8)

        if self.config.color_mode == "rgb":
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        read_duration_ms = (time.perf_counter() - start_time) * 1e3
        logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")

        return frame

    def _read_loop(self):
        """
        Internal loop run by the background thread for asynchronous reading.

        On each iteration:
        1. Reads a color frame
        2. Stores result in latest_frame (thread-safe)
        3. Sets new_frame_event to notify listeners

        Stops on DeviceNotConnectedError, logs other errors and continues.
        """
        while not self.stop_event.is_set():
            try:
                color_image = self.read()

                with self.frame_lock:
                    self.latest_frame = color_image
                self.new_frame_event.set()

            except DeviceNotConnectedError:
                break
            except Exception as e:
                logger.warning(f"Error reading frame in background thread for {self}: {e}")

    def _start_read_thread(self) -> None:
        """Starts or restarts the background read thread if it's not running."""
        if self.thread is not None and self.thread.is_alive():
            self.thread.join(timeout=0.1)
        if self.stop_event is not None:
            self.stop_event.set()

        self.stop_event = Event()
        self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
        self.thread.daemon = True
        self.thread.start()

    def _stop_read_thread(self) -> None:
        """Signals the background read thread to stop and waits for it to join."""
        if self.stop_event is not None:
            self.stop_event.set()

        if self.thread is not None and self.thread.is_alive():
            self.thread.join(timeout=2.0)

        self.thread = None
        self.stop_event = None

    def async_read(self, timeout_ms: float = 200) -> np.ndarray:
        """
        Reads the latest available frame asynchronously.

        This method retrieves the most recent frame captured by the background
        read thread. It does not block waiting for the camera hardware directly,
        but may wait up to timeout_ms for the background thread to provide a frame.

        Args:
            timeout_ms (float): Maximum time in milliseconds to wait for a frame
                to become available. Defaults to 200ms (0.2 seconds).

        Returns:
            np.ndarray: The latest captured frame as a NumPy array in the format
            (height, width, channels), processed according to configuration.

        Raises:
            DeviceNotConnectedError: If the camera is not connected.
            TimeoutError: If no frame becomes available within the specified timeout.
            RuntimeError: If an unexpected error occurs.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        if self.thread is None or not self.thread.is_alive():
            self._start_read_thread()

        if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
            thread_alive = self.thread is not None and self.thread.is_alive()
            raise TimeoutError(
                f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
                f"Read thread alive: {thread_alive}."
            )

        with self.frame_lock:
            frame = self.latest_frame
            self.new_frame_event.clear()

        if frame is None:
            raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")

        return frame

    def disconnect(self):
        """
        Stops the background read thread (if running).

        Raises:
            DeviceNotConnectedError: If the camera is already disconnected.
        """
        if not self.is_connected and self.thread is None:
            raise DeviceNotConnectedError(f"{self} not connected.")

        if self.thread is not None:
            self._stop_read_thread()

        if self.cam_manager is not None:
            self.cam_manager.disconnect()

        logger.info(f"{self} disconnected.")
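Putting the two new classes together, a minimal usage sketch (the robot IP is a placeholder; `fps`, `width`, and `height` fall back to the `CameraConfig` defaults):

```python
config = Reachy2CameraConfig(
    name="teleop",
    image_type="left",
    ip_address="192.168.0.200",  # placeholder robot IP
)
camera = Reachy2Camera(config)
camera.connect()

frame = camera.read()                        # blocking read, (H, W, 3), RGB by default
latest = camera.async_read(timeout_ms=200)   # latest frame from the background thread

camera.disconnect()
```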
@@ -37,8 +37,14 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
            from .realsense.camera_realsense import RealSenseCamera

            cameras[key] = RealSenseCamera(cfg)

        elif cfg.type == "reachy2_camera":
            from .reachy2_camera.reachy2_camera import Reachy2Camera

            cameras[key] = Reachy2Camera(cfg)

        else:
            raise ValueError(f"The motor type '{cfg.type}' is not valid.")
            raise ValueError(f"The camera type '{cfg.type}' is not valid.")

    return cameras
@@ -37,6 +37,7 @@ class DatasetConfig:
    revision: str | None = None
    use_imagenet_stats: bool = True
    video_backend: str = field(default_factory=get_safe_default_codec)
    streaming: bool = False


@dataclass
@@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
@@ -53,6 +53,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
    """

    n_obs_steps: int = 1
    normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)

    input_features: dict[str, PolicyFeature] = field(default_factory=dict)
    output_features: dict[str, PolicyFeature] = field(default_factory=dict)
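The new `normalization_mapping` field lets a policy config declare, per feature type, how inputs and outputs should be normalized. A hedged illustration follows; the enum members shown are assumptions about the `NormalizationMode` values defined in `lerobot.configs.types` (see the `FeatureType` diff below), not taken from this diff:

```python
from lerobot.configs.types import NormalizationMode

# Hypothetical mapping: visual inputs normalized by mean/std,
# proprioceptive state and actions scaled to min/max.
normalization_mapping = {
    "VISUAL": NormalizationMode.MEAN_STD,
    "STATE": NormalizationMode.MIN_MAX,
    "ACTION": NormalizationMode.MIN_MAX,
}
```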
@@ -24,7 +24,6 @@ class FeatureType(str, Enum):
    ENV = "ENV"
    ACTION = "ACTION"
    REWARD = "REWARD"
    LANGUAGE = "LANGUAGE"


class NormalizationMode(str, Enum):
@@ -21,14 +21,8 @@ OBS_ENV_STATE = "observation.environment_state"
OBS_STATE = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
OBS_LANGUAGE = "observation.language"
ACTION = "action"
REWARD = "next.reward"
TRUNCATED = "next.truncated"
DONE = "next.done"

OBS_LANGUAGE_TOKENS = "observation.language.tokens"
OBS_LANGUAGE_ATTENTION_MASK = "observation.language.attention_mask"

ROBOTS = "robots"
ROBOT_TYPE = "robot_type"
@@ -45,9 +39,6 @@ OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"

PREPROCESSOR_DEFAULT_NAME = "robot_preprocessor"
POSTPROCESSOR_DEFAULT_NAME = "robot_postprocessor"

if "LEROBOT_HOME" in os.environ:
    raise ValueError(
        f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
@@ -61,3 +52,8 @@ HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expandu
# calibration dir
default_calibration_path = HF_LEROBOT_HOME / "calibration"
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()


# streaming datasets
LOOKBACK_BACKTRACKTABLE = 100
LOOKAHEAD_BACKTRACKTABLE = 100
@@ -0,0 +1,502 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import shutil
from pathlib import Path

import pandas as pd
import tqdm

from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import (
    DEFAULT_CHUNK_SIZE,
    DEFAULT_DATA_FILE_SIZE_IN_MB,
    DEFAULT_DATA_PATH,
    DEFAULT_EPISODES_PATH,
    DEFAULT_VIDEO_FILE_SIZE_IN_MB,
    DEFAULT_VIDEO_PATH,
    get_parquet_file_size_in_mb,
    get_video_size_in_mb,
    to_parquet_with_hf_images,
    update_chunk_file_indices,
    write_info,
    write_stats,
    write_tasks,
)
from lerobot.datasets.video_utils import concatenate_video_files


def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
    """Validates that all dataset metadata have consistent properties.

    Ensures all datasets have the same fps, robot_type, and features to guarantee
    compatibility when aggregating them into a single dataset.

    Args:
        all_metadata: List of LeRobotDatasetMetadata objects to validate.

    Returns:
        tuple: A tuple containing (fps, robot_type, features) from the first metadata.

    Raises:
        ValueError: If any metadata has different fps, robot_type, or features
            than the first metadata in the list.
    """

    fps = all_metadata[0].fps
    robot_type = all_metadata[0].robot_type
    features = all_metadata[0].features

    for meta in tqdm.tqdm(all_metadata, desc="Validate all meta data"):
        if fps != meta.fps:
            raise ValueError(f"Same fps is expected, but got fps={meta.fps} instead of {fps}.")
        if robot_type != meta.robot_type:
            raise ValueError(
                f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
            )
        if features != meta.features:
            raise ValueError(
                f"Same features is expected, but got features={meta.features} instead of {features}."
            )

    return fps, robot_type, features


def update_data_df(df, src_meta, dst_meta):
    """Updates a data DataFrame with new indices and task mappings for aggregation.

    Adjusts episode indices, frame indices, and task indices to account for
    previously aggregated data in the destination dataset.

    Args:
        df: DataFrame containing the data to be updated.
        src_meta: Source dataset metadata.
        dst_meta: Destination dataset metadata.

    Returns:
        pd.DataFrame: Updated DataFrame with adjusted indices.
    """

    def _update(row):
        row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
        row["index"] = row["index"] + dst_meta.info["total_frames"]
        task = src_meta.tasks.iloc[row["task_index"]].name
        row["task_index"] = dst_meta.tasks.loc[task].task_index.item()
        return row

    return df.apply(_update, axis=1)


def update_meta_data(
    df,
    dst_meta,
    meta_idx,
    data_idx,
    videos_idx,
):
    """Updates metadata DataFrame with new chunk, file, and timestamp indices.

    Adjusts all indices and timestamps to account for previously aggregated
    data and videos in the destination dataset.

    Args:
        df: DataFrame containing the metadata to be updated.
        dst_meta: Destination dataset metadata.
        meta_idx: Dictionary containing current metadata chunk and file indices.
        data_idx: Dictionary containing current data chunk and file indices.
        videos_idx: Dictionary containing current video indices and timestamps.

    Returns:
        pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
    """

    def _update(row):
        row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"]
        row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"]
        row["data/chunk_index"] = row["data/chunk_index"] + data_idx["chunk"]
        row["data/file_index"] = row["data/file_index"] + data_idx["file"]
        for key, video_idx in videos_idx.items():
            row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + video_idx["chunk"]
            row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + video_idx["file"]
            row[f"videos/{key}/from_timestamp"] = (
                row[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
            )
            row[f"videos/{key}/to_timestamp"] = (
                row[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
            )

        row["dataset_from_index"] = row["dataset_from_index"] + dst_meta.info["total_frames"]
        row["dataset_to_index"] = row["dataset_to_index"] + dst_meta.info["total_frames"]
        row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
        return row

    return df.apply(_update, axis=1)
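The timestamp bookkeeping above reduces to a frame-count offset: when an episode's video is appended to an existing destination file, its `from`/`to` timestamps must be shifted by the duration already accumulated there. A worked example of the shift applied by `update_meta_data` (a sketch; the numbers are illustrative):

```python
# If the destination already holds 3000 frames at 15 fps, an appended episode's
# video timestamps are shifted by latest_duration = 3000 / 15 = 200.0 seconds.
total_frames, fps = 3000, 15
latest_duration = total_frames / fps
from_ts, to_ts = 0.0, 12.5  # episode span within its source video file
shifted = (from_ts + latest_duration, to_ts + latest_duration)
assert shifted == (200.0, 212.5)
```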
def aggregate_datasets(
|
||||
repo_ids: list[str],
|
||||
aggr_repo_id: str,
|
||||
roots: list[Path] | None = None,
|
||||
aggr_root: Path | None = None,
|
||||
data_files_size_in_mb: float | None = None,
|
||||
video_files_size_in_mb: float | None = None,
|
||||
chunk_size: int | None = None,
|
||||
):
|
||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||
|
||||
This is the main function that orchestrates the aggregation process by:
|
||||
1. Loading and validating all source dataset metadata
|
||||
2. Creating a new destination dataset with unified tasks
|
||||
3. Aggregating videos, data, and metadata from all source datasets
|
||||
4. Finalizing the aggregated dataset with proper statistics
|
||||
|
||||
Args:
|
||||
repo_ids: List of repository IDs for the datasets to aggregate.
|
||||
aggr_repo_id: Repository ID for the aggregated output dataset.
|
||||
roots: Optional list of root paths for the source datasets.
|
||||
aggr_root: Optional root path for the aggregated dataset.
|
||||
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
"""
|
||||
logging.info("Start aggregate_datasets")
|
||||
|
||||
if data_files_size_in_mb is None:
|
||||
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
|
||||
if video_files_size_in_mb is None:
|
||||
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
if chunk_size is None:
|
||||
chunk_size = DEFAULT_CHUNK_SIZE
|
||||
|
||||
all_metadata = (
|
||||
[LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
|
||||
if roots is None
|
||||
else [
|
||||
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
|
||||
]
|
||||
)
|
||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||
|
||||
dst_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=aggr_repo_id,
|
||||
fps=fps,
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
root=aggr_root,
|
||||
)
|
||||
|
||||
logging.info("Find all tasks")
|
||||
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
|
||||
dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
|
||||
|
||||
meta_idx = {"chunk": 0, "file": 0}
|
||||
data_idx = {"chunk": 0, "file": 0}
|
||||
videos_idx = {
|
||||
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
|
||||
}
|
||||
|
||||
dst_meta.episodes = {}
|
||||
|
||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
|
||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
|
||||
finalize_aggregation(dst_meta, all_metadata)
|
||||
logging.info("Aggregation complete.")
|
||||
|
||||
|
||||
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
    """Aggregates video chunks from a source dataset into the destination dataset.

    Handles video file concatenation and rotation based on file size limits.
    Creates new video files when size limits are exceeded.

    Args:
        src_meta: Source dataset metadata.
        dst_meta: Destination dataset metadata.
        videos_idx: Dictionary tracking video chunk and file indices.
        video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB).
        chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE).

    Returns:
        dict: Updated videos_idx with current chunk and file indices.
    """
    for key, video_idx in videos_idx.items():
        unique_chunk_file_pairs = {
            (chunk, file)
            for chunk, file in zip(
                src_meta.episodes[f"videos/{key}/chunk_index"],
                src_meta.episodes[f"videos/{key}/file_index"],
                strict=False,
            )
        }
        unique_chunk_file_pairs = sorted(unique_chunk_file_pairs)

        chunk_idx = video_idx["chunk"]
        file_idx = video_idx["file"]

        # Initialized here so both names are defined even when the source has no
        # video file for this key (the inner loop would otherwise never set them).
        update_latest_duration = False
        timestamps_shift_s = 0.0

        for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
            src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
                video_key=key,
                chunk_index=src_chunk_idx,
                file_index=src_file_idx,
            )

            dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
                video_key=key,
                chunk_index=chunk_idx,
                file_index=file_idx,
            )

            # If a new file is created, we don't want to increment the latest_duration
            update_latest_duration = False

            if not dst_path.exists():
                # First write to this destination file
                dst_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy(str(src_path), str(dst_path))
                continue  # not accumulating further, already copied the file in place

            # Check file sizes before appending
            src_size = get_video_size_in_mb(src_path)
            dst_size = get_video_size_in_mb(dst_path)

            if dst_size + src_size >= video_files_size_in_mb:
                # Rotate to a new chunk/file
                chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
                dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
                    video_key=key,
                    chunk_index=chunk_idx,
                    file_index=file_idx,
                )
                dst_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy(str(src_path), str(dst_path))
            else:
                # Get the timestamps shift for this video
                timestamps_shift_s = dst_meta.info["total_frames"] / dst_meta.info["fps"]

                # Append to existing video file
                concatenate_video_files(
                    [dst_path, src_path],
                    dst_path,
                )
                # Update the latest_duration when appending (shifts timestamps!)
                update_latest_duration = True

        # Update the videos_idx with the final chunk and file indices for this key
        videos_idx[key]["chunk"] = chunk_idx
        videos_idx[key]["file"] = file_idx

        if update_latest_duration:
            videos_idx[key]["latest_duration"] += timestamps_shift_s

    return videos_idx

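# A self-contained sketch of the size-based rotation rule used by aggregate_videos
# above; the numbers are made up. Appending is only allowed while the combined size
# stays under the per-file limit, otherwise writing rotates to a fresh chunk/file.
def _should_rotate(dst_size_mb: float, src_size_mb: float, limit_mb: float) -> bool:
    return dst_size_mb + src_size_mb >= limit_mb


assert _should_rotate(450.0, 80.0, 500.0) is True  # 530 MB >= 500 MB -> rotate to a new file
assert _should_rotate(300.0, 80.0, 500.0) is False  # 380 MB < 500 MB -> append in place
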
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
    """Aggregates data chunks from a source dataset into the destination dataset.

    Reads source data files, updates indices to match the aggregated dataset,
    and writes them to the destination with proper file rotation.

    Args:
        src_meta: Source dataset metadata.
        dst_meta: Destination dataset metadata.
        data_idx: Dictionary tracking data chunk and file indices.
        data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB).
        chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE).

    Returns:
        dict: Updated data_idx with current chunk and file indices.
    """
    unique_chunk_file_ids = {
        (c, f)
        for c, f in zip(
            src_meta.episodes["data/chunk_index"], src_meta.episodes["data/file_index"], strict=False
        )
    }
    unique_chunk_file_ids = sorted(unique_chunk_file_ids)

    for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
        src_path = src_meta.root / DEFAULT_DATA_PATH.format(
            chunk_index=src_chunk_idx, file_index=src_file_idx
        )
        df = pd.read_parquet(src_path)
        df = update_data_df(df, src_meta, dst_meta)

        data_idx = append_or_create_parquet_file(
            df,
            src_path,
            data_idx,
            data_files_size_in_mb,
            chunk_size,
            DEFAULT_DATA_PATH,
            contains_images=len(dst_meta.image_keys) > 0,
            aggr_root=dst_meta.root,
        )

    return data_idx

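# `update_data_df` is not shown in this hunk. As an illustrative sketch (an assumption,
# not the real implementation), re-indexing a source frame table into the aggregated
# dataset mostly means shifting the global counters by the totals accumulated so far;
# the column names follow the LeRobotDataset schema. Task indices are remapped against
# the merged task table rather than shifted.
import pandas as pd


def shift_indices(df: pd.DataFrame, episode_offset: int, frame_offset: int) -> pd.DataFrame:
    df = df.copy()
    df["episode_index"] += episode_offset  # episodes continue after those already merged
    df["index"] += frame_offset  # the global frame index keeps growing across sources
    return df
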
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
    """Aggregates metadata from a source dataset into the destination dataset.

    Reads source metadata files, updates all indices and timestamps,
    and writes them to the destination with proper file rotation.

    Args:
        src_meta: Source dataset metadata.
        dst_meta: Destination dataset metadata.
        meta_idx: Dictionary tracking metadata chunk and file indices.
        data_idx: Dictionary tracking data chunk and file indices.
        videos_idx: Dictionary tracking video indices and timestamps.

    Returns:
        dict: Updated meta_idx with current chunk and file indices.
    """
    chunk_file_ids = {
        (c, f)
        for c, f in zip(
            src_meta.episodes["meta/episodes/chunk_index"],
            src_meta.episodes["meta/episodes/file_index"],
            strict=False,
        )
    }
    chunk_file_ids = sorted(chunk_file_ids)
    for chunk_idx, file_idx in chunk_file_ids:
        src_path = src_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
        df = pd.read_parquet(src_path)
        df = update_meta_data(
            df,
            dst_meta,
            meta_idx,
            data_idx,
            videos_idx,
        )

        for k in videos_idx:
            videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]

        meta_idx = append_or_create_parquet_file(
            df,
            src_path,
            meta_idx,
            DEFAULT_DATA_FILE_SIZE_IN_MB,
            DEFAULT_CHUNK_SIZE,
            DEFAULT_EPISODES_PATH,
            contains_images=False,
            aggr_root=dst_meta.root,
        )

    return meta_idx

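# A sketch of the per-episode bookkeeping that `update_meta_data` (not shown in this
# hunk) has to perform: every episode row must point at the destination chunk/file its
# data and videos were copied into. The body below is an assumption for illustration;
# the column names follow the v3.0 episodes schema.
def point_episodes_at_destination(df, data_idx, videos_idx):
    df = df.copy()
    df["data/chunk_index"] = data_idx["chunk"]
    df["data/file_index"] = data_idx["file"]
    for key, idx in videos_idx.items():
        df[f"videos/{key}/chunk_index"] = idx["chunk"]
        df[f"videos/{key}/file_index"] = idx["file"]
    return df
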
def append_or_create_parquet_file(
    df: pd.DataFrame,
    src_path: Path,
    idx: dict[str, int],
    max_mb: float,
    chunk_size: int,
    default_path: str,
    contains_images: bool = False,
    aggr_root: Path | None = None,
):
    """Appends data to an existing parquet file or creates a new one based on size constraints.

    Manages file rotation when size limits are exceeded to prevent individual files
    from becoming too large. Handles both regular parquet files and those containing images.

    Args:
        df: DataFrame to write to the parquet file.
        src_path: Path to the source file (used for size estimation).
        idx: Dictionary containing current 'chunk' and 'file' indices.
        max_mb: Maximum allowed file size in MB before rotation.
        chunk_size: Maximum number of files per chunk before incrementing chunk index.
        default_path: Format string for generating file paths.
        contains_images: Whether the data contains images requiring special handling.
        aggr_root: Root path for the aggregated dataset.

    Returns:
        dict: Updated index dictionary with current chunk and file indices.
    """
    dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])

    if not dst_path.exists():
        dst_path.parent.mkdir(parents=True, exist_ok=True)
        if contains_images:
            to_parquet_with_hf_images(df, dst_path)
        else:
            df.to_parquet(dst_path)
        return idx

    src_size = get_parquet_file_size_in_mb(src_path)
    dst_size = get_parquet_file_size_in_mb(dst_path)

    if dst_size + src_size >= max_mb:
        idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
        new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
        new_path.parent.mkdir(parents=True, exist_ok=True)
        final_df = df
        target_path = new_path
    else:
        existing_df = pd.read_parquet(dst_path)
        final_df = pd.concat([existing_df, df], ignore_index=True)
        target_path = dst_path

    if contains_images:
        to_parquet_with_hf_images(final_df, target_path)
    else:
        final_df.to_parquet(target_path)

    return idx

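# A usage sketch for the helper above; the paths are hypothetical, and
# DEFAULT_DATA_PATH expands to "data/chunk-xxx/file-xxx.parquet".
#
# idx = {"chunk": 0, "file": 0}
# idx = append_or_create_parquet_file(
#     df,
#     src_path=Path("src_dataset/data/chunk-000/file-000.parquet"),
#     idx=idx,
#     max_mb=DEFAULT_DATA_FILE_SIZE_IN_MB,
#     chunk_size=DEFAULT_CHUNK_SIZE,
#     default_path=DEFAULT_DATA_PATH,
#     aggr_root=Path("merged_dataset"),
# )
# # idx now points at the chunk/file that received `df`, after any rotation.
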
def finalize_aggregation(aggr_meta, all_metadata):
    """Finalizes the dataset aggregation by writing summary files and statistics.

    Writes the tasks file, info file with total counts and splits, and
    aggregated statistics from all source datasets.

    Args:
        aggr_meta: Aggregated dataset metadata.
        all_metadata: List of all source dataset metadata objects.
    """
    logging.info("write tasks")
    write_tasks(aggr_meta.tasks, aggr_meta.root)

    logging.info("write info")
    aggr_meta.info.update(
        {
            "total_tasks": len(aggr_meta.tasks),
            "total_episodes": sum(m.total_episodes for m in all_metadata),
            "total_frames": sum(m.total_frames for m in all_metadata),
            "splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
        }
    )
    write_info(aggr_meta.info, aggr_meta.root)

    logging.info("write stats")
    aggr_meta.stats = aggregate_stats([m.stats for m in all_metadata])
    write_stats(aggr_meta.stats, aggr_meta.root)
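
# `aggregate_stats` is imported from elsewhere in the codebase. As a sketch of the
# idea (an assumption, not the library function): per-dataset means combine as a
# frame-count-weighted average; only the mean is shown here for brevity.
import numpy as np


def combine_means(means: list[np.ndarray], counts: list[int]) -> np.ndarray:
    total = sum(counts)
    return sum(m * (n / total) for m, n in zip(means, counts, strict=True))
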
@@ -14,33 +14,13 @@

import packaging.version

-V2_MESSAGE = """
+V30_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.

-We introduced a new format since v2.0 which is not backward compatible with v1.x.
-Please, use our conversion script. Modify the following command with your own task description:
+We introduced a new format since v3.0 which is not backward compatible with v2.1.
+Please, update your dataset to the new format using this command:
```
-python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \\
-    --repo-id {repo_id} \\
-    --single-task "TASK DESCRIPTION."  # <---- /!\\ Replace TASK DESCRIPTION /!\\
```

-A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
-peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
-cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
-target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
-sweatshirt.", ...

If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""

V21_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
While the current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
```
-python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id={repo_id}
+python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id}
```

If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
@@ -58,7 +38,12 @@ class CompatibilityError(Exception): ...

class BackwardCompatibilityError(CompatibilityError):
    def __init__(self, repo_id: str, version: packaging.version.Version):
-        message = V2_MESSAGE.format(repo_id=repo_id, version=version)
+        if version.major == 2 and version.minor == 1:
+            message = V30_MESSAGE.format(repo_id=repo_id, version=version)
+        else:
+            raise NotImplementedError(
+                "Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)."
+            )
        super().__init__(message)

@@ -25,6 +25,7 @@ from lerobot.datasets.lerobot_dataset import (
    LeRobotDatasetMetadata,
    MultiLeRobotDataset,
)
+from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms

IMAGENET_STATS = {
@@ -87,15 +88,26 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
        cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
    )
    delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
-    dataset = LeRobotDataset(
-        cfg.dataset.repo_id,
-        root=cfg.dataset.root,
-        episodes=cfg.dataset.episodes,
-        delta_timestamps=delta_timestamps,
-        image_transforms=image_transforms,
-        revision=cfg.dataset.revision,
-        video_backend=cfg.dataset.video_backend,
-    )
+    if not cfg.dataset.streaming:
+        dataset = LeRobotDataset(
+            cfg.dataset.repo_id,
+            root=cfg.dataset.root,
+            episodes=cfg.dataset.episodes,
+            delta_timestamps=delta_timestamps,
+            image_transforms=image_transforms,
+            revision=cfg.dataset.revision,
+            video_backend=cfg.dataset.video_backend,
+        )
+    else:
+        dataset = StreamingLeRobotDataset(
+            cfg.dataset.repo_id,
+            root=cfg.dataset.root,
+            episodes=cfg.dataset.episodes,
+            delta_timestamps=delta_timestamps,
+            image_transforms=image_transforms,
+            revision=cfg.dataset.revision,
+            max_num_shards=cfg.num_workers,
+        )
else:
    raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
    dataset = MultiLeRobotDataset(
File diff suppressed because it is too large
@@ -337,13 +337,11 @@ def compute_sampler_weights(
    if len(offline_dataset) > 0:
        offline_data_mask_indices = []
        for start_index, end_index in zip(
-            offline_dataset.episode_data_index["from"],
-            offline_dataset.episode_data_index["to"],
+            offline_dataset.meta.episodes["dataset_from_index"],
+            offline_dataset.meta.episodes["dataset_to_index"],
            strict=True,
        ):
-            offline_data_mask_indices.extend(
-                range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
-            )
+            offline_data_mask_indices.extend(range(start_index, end_index - offline_drop_n_last_frames))
        offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
        offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
        weights.append(
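
# A tiny self-contained illustration of the boundary mask built above, with made-up
# episode boundaries. Dropping the last n frames of each episode keeps delta-timestamp
# queries from reaching across into the next episode.
import torch

from_indices, to_indices = [0, 100], [100, 250]  # two hypothetical episodes
drop_n_last = 2
mask = torch.zeros(250, dtype=torch.bool)
for start, end in zip(from_indices, to_indices, strict=True):
    mask[start : end - drop_n_last] = True
assert mask[97] and not mask[98] and not mask[99]  # tail of episode 0 is masked out
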
@@ -1,95 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Sequence
from typing import Any

from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor import DataProcessorPipeline


def aggregate_pipeline_dataset_features(
    pipeline: DataProcessorPipeline,
    initial_features: dict[str, Any],
    *,
    use_videos: bool = True,
    patterns: Sequence[str] | None = None,
) -> dict[str, dict]:
    """
    Aggregates the pipeline's features and returns a features dict ready for the dataset,
    filtered to only those keys matching any of the given patterns (for action/state only).

    - `initial_features`: raw camera specs, e.g. {"front": (h, w, c), ...}
    - `use_videos`: whether to treat image features as video streams
    - `patterns`: regexes to filter action & state features; images are included
      whenever use_videos=True, regardless of patterns.
    """
    import re

    # Gather everything the pipeline features specifies, seeded with hardware cams:
    all_features = pipeline.transform_features(initial_features)

    # Helper to decide which action/state keys survive the `patterns` filter:
    def keep(key: str) -> bool:
        if patterns is None:
            return True
        return any(re.search(pat, key) for pat in patterns)

    # Start with hardware dict, injecting initial cameras if videos are ON:
    hw: dict[str, dict[str, Any]] = {}
    if use_videos:
        cams = {
            name: shape
            for name, shape in initial_features.items()
            if isinstance(shape, tuple) and len(shape) == 3
        }
        if cams:
            hw["observation"] = dict(cams)

    # Go over every feature from the pipeline and merge:
    for full_key, ty in all_features.items():
        if full_key.startswith(f"{ACTION}."):
            # action.<feat>
            if not keep(full_key):
                continue
            name = full_key[len(f"{ACTION}.") :]
            hw.setdefault(ACTION, {})[name] = ty

        elif full_key.startswith(f"{OBS_STATE}."):
            # observation.state.<feat>
            if not keep(full_key):
                continue
            name = full_key[len(f"{OBS_STATE}.") :]
            hw.setdefault("observation", {})[name] = ty

        elif full_key.startswith(f"{OBS_IMAGES}."):
            # observation.images.<cam>
            # images obey ONLY the use_videos flag, not patterns
            if not use_videos:
                continue
            name = full_key[len(f"{OBS_IMAGES}.") :]
            hw.setdefault("observation", {})[name] = ty

        else:
            # anything else (e.g. policy-only features) is ignored here
            continue

    out: dict[str, dict] = {}
    if ACTION in hw:
        out.update(hw_to_dataset_features(hw[ACTION], ACTION, use_videos))
    if "observation" in hw:
        out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos))

    return out
@@ -21,7 +21,8 @@ import torch
class EpisodeAwareSampler:
    def __init__(
        self,
-        episode_data_index: dict,
+        dataset_from_indices: list[int],
+        dataset_to_indices: list[int],
        episode_indices_to_use: list | None = None,
        drop_n_first_frames: int = 0,
        drop_n_last_frames: int = 0,
@@ -30,7 +31,8 @@ class EpisodeAwareSampler:
        """Sampler that optionally incorporates episode boundary information.

        Args:
-            episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
+            dataset_from_indices: List of indices containing the start of each episode in the dataset.
+            dataset_to_indices: List of indices containing the end of each episode in the dataset.
            episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
                Assumes that episodes are indexed from 0 to N-1.
            drop_n_first_frames: Number of frames to drop from the start of each episode.
@@ -39,12 +41,10 @@ class EpisodeAwareSampler:
        """
        indices = []
        for episode_idx, (start_index, end_index) in enumerate(
-            zip(episode_data_index["from"], episode_data_index["to"], strict=True)
+            zip(dataset_from_indices, dataset_to_indices, strict=True)
        ):
            if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
-                indices.extend(
-                    range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
-                )
+                indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))

        self.indices = indices
        self.shuffle = shuffle
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections.abc import Callable, Generator, Iterator
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from lerobot.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
Backtrackable,
|
||||
LookAheadError,
|
||||
LookBackError,
|
||||
check_version_compatibility,
|
||||
find_float_index,
|
||||
get_delta_indices,
|
||||
is_float_in_list,
|
||||
item_to_torch,
|
||||
safe_shard,
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
VideoDecoderCache,
|
||||
decode_video_frames_torchcodec,
|
||||
)
|
||||
|
||||
|
||||
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
"""LeRobotDataset with streaming capabilities.
|
||||
|
||||
This class extends LeRobotDataset to add streaming functionality, allowing data to be streamed
|
||||
rather than loaded entirely into memory. This is especially useful for large datasets that may
|
||||
not fit in memory or when you want to quickly explore a dataset without downloading it completely.
|
||||
|
||||
The key innovation is using a Backtrackable iterator that maintains a bounded buffer of recent
|
||||
items, allowing us to access previous frames for delta timestamps without loading the entire
|
||||
dataset into memory.
|
||||
|
||||
Example:
|
||||
Basic usage:
|
||||
```python
|
||||
from lerobot.common.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
# Create a streaming dataset with delta timestamps
|
||||
delta_timestamps = {
|
||||
"observation.image": [-1.0, -0.5, 0.0], # 1 sec ago, 0.5 sec ago, current
|
||||
"action": [0.0, 0.1, 0.2], # current, 0.1 sec future, 0.2 sec future
|
||||
}
|
||||
|
||||
dataset = StreamingLeRobotDataset(
|
||||
repo_id="your-dataset-repo-id",
|
||||
delta_timestamps=delta_timestamps,
|
||||
streaming=True,
|
||||
buffer_size=1000,
|
||||
)
|
||||
|
||||
# Iterate over the dataset
|
||||
for i, item in enumerate(dataset):
|
||||
print(f"Sample {i}: Episode {item['episode_index']} Frame {item['frame_index']}")
|
||||
# item will contain stacked frames according to delta_timestamps
|
||||
if i >= 10:
|
||||
break
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
episodes: list[int] | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
streaming: bool = True,
|
||||
buffer_size: int = 1000,
|
||||
max_num_shards: int = 16,
|
||||
seed: int = 42,
|
||||
rng: np.random.Generator | None = None,
|
||||
shuffle: bool = True,
|
||||
):
|
||||
"""Initialize a StreamingLeRobotDataset.
|
||||
|
||||
Args:
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list.
|
||||
image_transforms (Callable | None, optional): Transform to apply to image data.
|
||||
tolerance_s (float, optional): Tolerance in seconds for timestamp matching.
|
||||
revision (str, optional): Git revision id (branch name, tag, or commit hash).
|
||||
force_cache_sync (bool, optional): Flag to sync and refresh local files first.
|
||||
streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True.
|
||||
buffer_size (int, optional): Buffer size for shuffling when streaming. Defaults to 1000.
|
||||
max_num_shards (int, optional): Number of shards to re-shard the input dataset into. Defaults to 16.
|
||||
seed (int, optional): Reproducibility random seed.
|
||||
rng (np.random.Generator | None, optional): Random number generator.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
self.streaming_from_local = root is not None
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.seed = seed
|
||||
self.rng = rng if rng is not None else np.random.default_rng(seed)
|
||||
self.shuffle = shuffle
|
||||
|
||||
self.streaming = streaming
|
||||
self.buffer_size = buffer_size
|
||||
|
||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||
self.video_decoder_cache = None
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Load metadata
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
# Check version
|
||||
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||
|
||||
self.delta_timestamps = None
|
||||
self.delta_indices = None
|
||||
|
||||
if delta_timestamps is not None:
|
||||
self._validate_delta_timestamp_keys(delta_timestamps) # raises ValueError if invalid
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
self.hf_dataset: datasets.IterableDataset = load_dataset(
|
||||
self.repo_id if not self.streaming_from_local else str(self.root),
|
||||
split="train",
|
||||
streaming=self.streaming,
|
||||
data_files="data/*/*.parquet",
|
||||
revision=self.revision,
|
||||
)
|
||||
|
||||
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
|
||||
|
||||
@property
|
||||
def num_frames(self):
|
||||
return self.meta.total_frames
|
||||
|
||||
@property
|
||||
def num_episodes(self):
|
||||
return self.meta.total_episodes
|
||||
|
||||
@property
|
||||
def fps(self):
|
||||
return self.meta.fps
|
||||
|
||||
@staticmethod
|
||||
def _iter_random_indices(
|
||||
rng: np.random.Generator, buffer_size: int, random_batch_size=100
|
||||
) -> Iterator[int]:
|
||||
while True:
|
||||
yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size))
|
||||
|
||||
@staticmethod
|
||||
def _infinite_generator_over_elements(rng: np.random.Generator, elements: list[int]) -> Iterator[int]:
|
||||
while True:
|
||||
yield rng.choice(elements)
|
||||
|
||||
# TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading.
|
||||
# The current sequential iteration is a bottleneck. A producer-consumer pattern
|
||||
# could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding)
|
||||
# in parallel, feeding a queue from which this iterator will yield processed items.
|
||||
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
|
||||
if self.video_decoder_cache is None:
|
||||
self.video_decoder_cache = VideoDecoderCache()
|
||||
|
||||
# keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions
|
||||
rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng
|
||||
|
||||
buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
|
||||
|
||||
idx_to_backtrack_dataset = {
|
||||
idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards))
|
||||
for idx in range(self.num_shards)
|
||||
}
|
||||
|
||||
# This buffer is populated while iterating on the dataset's shards
|
||||
# the logic is to add 2 levels of randomness:
|
||||
# (1) sample one shard at random from the ones available, and
|
||||
# (2) sample one frame from the shard sampled at (1)
|
||||
frames_buffer = []
|
||||
while available_shards := list(idx_to_backtrack_dataset.keys()):
|
||||
shard_key = next(self._infinite_generator_over_elements(rng, available_shards))
|
||||
backtrack_dataset = idx_to_backtrack_dataset[shard_key] # selects which shard to iterate on
|
||||
|
||||
try:
|
||||
for frame in self.make_frame(backtrack_dataset):
|
||||
if len(frames_buffer) == self.buffer_size:
|
||||
i = next(buffer_indices_generator) # samples a element from the buffer
|
||||
yield frames_buffer[i]
|
||||
frames_buffer[i] = frame
|
||||
else:
|
||||
frames_buffer.append(frame)
|
||||
break # random shard sampled, switch shard
|
||||
except (
|
||||
RuntimeError,
|
||||
StopIteration,
|
||||
): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7
|
||||
del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
|
||||
|
||||
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
|
||||
rng.shuffle(frames_buffer)
|
||||
yield from frames_buffer
|
||||
|
||||
def _get_window_steps(
|
||||
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
|
||||
) -> tuple[int, int]:
|
||||
if delta_timestamps is None:
|
||||
return 1, 1
|
||||
|
||||
if not dynamic_bounds:
|
||||
# Fix the windows
|
||||
lookback = LOOKBACK_BACKTRACKTABLE
|
||||
lookahead = LOOKAHEAD_BACKTRACKTABLE
|
||||
else:
|
||||
# Dynamically adjust the windows based on the given delta_timesteps
|
||||
all_timestamps = sum(delta_timestamps.values(), [])
|
||||
lookback = min(all_timestamps) * self.fps
|
||||
lookahead = max(all_timestamps) * self.fps
|
||||
|
||||
# When lookback is >=0 it means no negative timesteps have been provided
|
||||
lookback = 0 if lookback >= 0 else (lookback * -1)
|
||||
|
||||
return lookback, lookahead
|
||||
|
||||
def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
|
||||
lookback, lookahead = self._get_window_steps(self.delta_timestamps)
|
||||
return Backtrackable(dataset, history=lookback, lookahead=lookahead)
|
||||
|
||||
def _make_timestamps_from_indices(
|
||||
self, start_ts: float, indices: dict[str, list[int]] | None = None
|
||||
) -> dict[str, list[float]]:
|
||||
if indices is not None:
|
||||
return {
|
||||
key: (
|
||||
start_ts + torch.tensor(indices[key]) / self.fps
|
||||
).tolist() # NOTE: why not delta_timestamps directly?
|
||||
for key in self.delta_timestamps
|
||||
}
|
||||
else:
|
||||
return dict.fromkeys(self.meta.video_keys, [start_ts])
|
||||
|
||||
def _make_padding_camera_frame(self, camera_key: str):
|
||||
"""Variable-shape padding frame for given camera keys, given in (H, W, C)"""
|
||||
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
|
||||
|
||||
def _get_video_frame_padding_mask(
|
||||
self,
|
||||
video_frames: dict[str, torch.Tensor],
|
||||
query_timestamps: dict[str, list[float]],
|
||||
original_timestamps: dict[str, list[float]],
|
||||
) -> dict[str, torch.BoolTensor]:
|
||||
padding_mask = {}
|
||||
|
||||
for video_key, timestamps in original_timestamps.items():
|
||||
if video_key not in video_frames:
|
||||
continue # only padding on video keys that are available
|
||||
frames = []
|
||||
mask = []
|
||||
padding_frame = self._make_padding_camera_frame(video_key)
|
||||
for ts in timestamps:
|
||||
if is_float_in_list(ts, query_timestamps[video_key]):
|
||||
idx = find_float_index(ts, query_timestamps[video_key])
|
||||
frames.append(video_frames[video_key][idx, :])
|
||||
mask.append(False)
|
||||
else:
|
||||
frames.append(padding_frame)
|
||||
mask.append(True)
|
||||
|
||||
padding_mask[f"{video_key}_is_pad"] = torch.BoolTensor(mask)
|
||||
|
||||
return padding_mask
|
||||
|
||||
def make_frame(
|
||||
self, dataset_iterator: Backtrackable, previous_dataset_iterator: Backtrackable | None = None
|
||||
) -> Generator:
|
||||
"""Makes a frame starting from a dataset iterator"""
|
||||
item = next(dataset_iterator)
|
||||
item = item_to_torch(item)
|
||||
|
||||
updates = [] # list of "updates" to apply to the item retrieved from hf_dataset (w/o camera features)
|
||||
|
||||
# Get episode index from the item
|
||||
ep_idx = item["episode_index"]
|
||||
|
||||
# "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps)
|
||||
current_ts = item["index"] / self.fps
|
||||
|
||||
episode_boundaries_ts = {
|
||||
key: (
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
|
||||
)
|
||||
for key in self.meta.video_keys
|
||||
}
|
||||
|
||||
# Apply delta querying logic if necessary
|
||||
if self.delta_indices is not None:
|
||||
query_result, padding = self._get_delta_frames(dataset_iterator, item)
|
||||
updates.append(query_result)
|
||||
updates.append(padding)
|
||||
|
||||
# Load video frames, when needed
|
||||
if len(self.meta.video_keys) > 0:
|
||||
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
|
||||
|
||||
# Some timestamps might not result available considering the episode's boundaries
|
||||
query_timestamps = self._get_query_timestamps(
|
||||
current_ts, self.delta_indices, episode_boundaries_ts
|
||||
)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.meta.camera_keys
|
||||
for cam in image_keys:
|
||||
video_frames[cam] = self.image_transforms(video_frames[cam])
|
||||
|
||||
updates.append(video_frames)
|
||||
|
||||
if self.delta_indices is not None:
|
||||
# We always return the same number of frames. Unavailable frames are padded.
|
||||
padding_mask = self._get_video_frame_padding_mask(
|
||||
video_frames, query_timestamps, original_timestamps
|
||||
)
|
||||
updates.append(padding_mask)
|
||||
|
||||
result = item.copy()
|
||||
for update in updates:
|
||||
result.update(update)
|
||||
|
||||
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
|
||||
|
||||
yield result
|
||||
|
||||
def _get_query_timestamps(
|
||||
self,
|
||||
current_ts: float,
|
||||
query_indices: dict[str, list[int]] | None = None,
|
||||
episode_boundaries_ts: dict[str, tuple[float, float]] | None = None,
|
||||
) -> dict[str, list[float]]:
|
||||
query_timestamps = {}
|
||||
keys_to_timestamps = self._make_timestamps_from_indices(current_ts, query_indices)
|
||||
for key in self.meta.video_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
timestamps = keys_to_timestamps[key]
|
||||
# Clamp out timesteps outside of episode boundaries
|
||||
query_timestamps[key] = torch.clamp(
|
||||
torch.tensor(timestamps), *episode_boundaries_ts[key]
|
||||
).tolist()
|
||||
|
||||
else:
|
||||
query_timestamps[key] = [current_ts]
|
||||
|
||||
return query_timestamps
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
|
||||
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
||||
the main process and a subprocess fails to access it.
|
||||
"""
|
||||
|
||||
item = {}
|
||||
for video_key, query_ts in query_timestamps.items():
|
||||
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
||||
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
||||
frames = decode_video_frames_torchcodec(
|
||||
video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
|
||||
)
|
||||
|
||||
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
|
||||
|
||||
return item
|
||||
|
||||
def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
|
||||
# TODO(fracapuano): Modularize this function, refactor the code
|
||||
"""Get frames with delta offsets using the backtrackable iterator.
|
||||
|
||||
Args:
|
||||
current_item (dict): Current item from the iterator.
|
||||
ep_idx (int): Episode index.
|
||||
|
||||
Returns:
|
||||
tuple: (query_result, padding) - frames at delta offsets and padding info.
|
||||
"""
|
||||
current_episode_idx = current_item["episode_index"]
|
||||
|
||||
# Prepare results
|
||||
query_result = {}
|
||||
padding = {}
|
||||
|
||||
for key, delta_indices in self.delta_indices.items():
|
||||
if key in self.meta.video_keys:
|
||||
continue # visual frames are decoded separately
|
||||
|
||||
target_frames = []
|
||||
is_pad = []
|
||||
|
||||
# Create a results dictionary to store frames in processing order, then reconstruct original order for stacking
|
||||
delta_results = {}
|
||||
|
||||
# Separate and sort deltas by difficulty (easier operations first)
|
||||
negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True) # [-1, -2, -3, ...]
|
||||
positive_deltas = sorted([d for d in delta_indices if d > 0]) # [1, 2, 3, ...]
|
||||
zero_deltas = [d for d in delta_indices if d == 0]
|
||||
|
||||
# Process zero deltas (current frame)
|
||||
for delta in zero_deltas:
|
||||
delta_results[delta] = (
|
||||
current_item[key],
|
||||
False,
|
||||
)
|
||||
|
||||
# Process negative deltas in order of increasing difficulty
|
||||
lookback_failed = False
|
||||
|
||||
last_successful_frame = current_item[key]
|
||||
|
||||
for delta in negative_deltas:
|
||||
if lookback_failed:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
continue
|
||||
|
||||
try:
|
||||
steps_back = abs(delta)
|
||||
if dataset_iterator.can_peek_back(steps_back):
|
||||
past_item = dataset_iterator.peek_back(steps_back)
|
||||
past_item = item_to_torch(past_item)
|
||||
|
||||
if past_item["episode_index"] == current_episode_idx:
|
||||
delta_results[delta] = (past_item[key], False)
|
||||
last_successful_frame = past_item[key]
|
||||
|
||||
else:
|
||||
raise LookBackError("Retrieved frame is from different episode!")
|
||||
else:
|
||||
raise LookBackError("Cannot go back further than the history buffer!")
|
||||
|
||||
except LookBackError:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
lookback_failed = True # All subsequent negative deltas will also fail
|
||||
|
||||
# Process positive deltas in order of increasing difficulty
|
||||
lookahead_failed = False
|
||||
last_successful_frame = current_item[key]
|
||||
|
||||
for delta in positive_deltas:
|
||||
if lookahead_failed:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
continue
|
||||
|
||||
try:
|
||||
if dataset_iterator.can_peek_ahead(delta):
|
||||
future_item = dataset_iterator.peek_ahead(delta)
|
||||
future_item = item_to_torch(future_item)
|
||||
|
||||
if future_item["episode_index"] == current_episode_idx:
|
||||
delta_results[delta] = (future_item[key], False)
|
||||
last_successful_frame = future_item[key]
|
||||
|
||||
else:
|
||||
raise LookAheadError("Retrieved frame is from different episode!")
|
||||
else:
|
||||
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
|
||||
|
||||
except LookAheadError:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
lookahead_failed = True # All subsequent positive deltas will also fail
|
||||
|
||||
# Reconstruct original order for stacking
|
||||
for delta in delta_indices:
|
||||
frame, is_padded = delta_results[delta]
|
||||
|
||||
# add batch dimension for stacking
|
||||
target_frames.append(frame) # frame.unsqueeze(0))
|
||||
is_pad.append(is_padded)
|
||||
|
||||
# Stack frames and add to results
|
||||
if target_frames:
|
||||
query_result[key] = torch.stack(target_frames)
|
||||
padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad)
|
||||
|
||||
return query_result, padding
|
||||
|
||||
def _validate_delta_timestamp_keys(self, delta_timestamps: dict[list[float]]) -> None:
|
||||
"""
|
||||
Validate that all keys in delta_timestamps correspond to actual features in the dataset.
|
||||
|
||||
Raises:
|
||||
ValueError: If any delta timestamp key doesn't correspond to a dataset feature.
|
||||
"""
|
||||
if delta_timestamps is None:
|
||||
return
|
||||
|
||||
# Get all available feature keys from the dataset metadata
|
||||
available_features = set(self.meta.features.keys())
|
||||
|
||||
# Get all keys from delta_timestamps
|
||||
delta_keys = set(delta_timestamps.keys())
|
||||
|
||||
# Find any keys that don't correspond to features
|
||||
invalid_keys = delta_keys - available_features
|
||||
|
||||
if invalid_keys:
|
||||
raise ValueError(
|
||||
f"The following delta_timestamp keys do not correspond to dataset features: {invalid_keys}. "
|
||||
f"Available features are: {sorted(available_features)}"
|
||||
)
|
||||
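
# A minimal sketch of the bounded-history idea behind `Backtrackable`, built only on
# the standard library. The real class (in lerobot.datasets.utils) also supports
# lookahead peeking; this toy version keeps the last `history` items so negative
# delta timestamps can be served without materializing the whole stream.
from collections import deque


class TinyBacktrackable:
    def __init__(self, iterable, history: int = 10):
        self._it = iter(iterable)
        self._past = deque(maxlen=history)  # bounded buffer of already-consumed items

    def __next__(self):
        item = next(self._it)
        self._past.append(item)
        return item

    def can_peek_back(self, steps: int) -> bool:
        return steps < len(self._past)

    def peek_back(self, steps: int):
        # steps=1 returns the previous item, steps=2 the one before that, ...
        return self._past[-1 - steps]
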
@@ -17,43 +17,57 @@ import contextlib
import importlib.resources
import json
import logging
-from collections.abc import Iterator
-from itertools import accumulate
+from collections import deque
+from collections.abc import Iterable, Iterator
from pathlib import Path
from pprint import pformat
from types import SimpleNamespace
-from typing import Any
+from typing import Any, Deque, Generic, TypeVar

import datasets
-import jsonlines
import numpy as np
import packaging.version
+import pandas
import pandas as pd
import pyarrow.parquet as pq
import torch
from datasets import Dataset, concatenate_datasets
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms

-from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
+from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.backward_compatibility import (
-    V21_MESSAGE,
+    FUTURE_MESSAGE,
    BackwardCompatibilityError,
+    ForwardCompatibilityError,
)
from lerobot.utils.utils import is_valid_numpy_dtype_string

-DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk
+DEFAULT_CHUNK_SIZE = 1000  # Max number of files per chunk
+DEFAULT_DATA_FILE_SIZE_IN_MB = 100  # Max size per file
+DEFAULT_VIDEO_FILE_SIZE_IN_MB = 500  # Max size per file

INFO_PATH = "meta/info.json"
-EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
-EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
-TASKS_PATH = "meta/tasks.jsonl"
-
-DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
-DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
-DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
+EPISODES_DIR = "meta/episodes"
+DATA_DIR = "data"
+VIDEO_DIR = "videos"
+
+CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
+DEFAULT_TASKS_PATH = "meta/tasks.parquet"
+DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
+DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
+DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
+DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
+
+LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
+LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
+LEGACY_TASKS_PATH = "meta/tasks.jsonl"
+LEGACY_DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
+LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
---
|
||||
@@ -73,6 +87,67 @@ DEFAULT_FEATURES = {
|
||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
|
||||
metadata = pq.read_metadata(parquet_path)
|
||||
total_uncompressed_size = 0
|
||||
for row_group in range(metadata.num_row_groups):
|
||||
rg_metadata = metadata.row_group(row_group)
|
||||
for column in range(rg_metadata.num_columns):
|
||||
col_metadata = rg_metadata.column(column)
|
||||
total_uncompressed_size += col_metadata.total_uncompressed_size
|
||||
return total_uncompressed_size / (1024**2)
|
||||
|
||||
|
||||
def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
|
||||
return hf_ds.data.nbytes // (1024**2)
|
||||
|
||||
|
||||
def get_hf_dataset_cache_dir(hf_ds: Dataset) -> Path | None:
|
||||
if hf_ds.cache_files is None or len(hf_ds.cache_files) == 0:
|
||||
return None
|
||||
return Path(hf_ds.cache_files[0]["filename"]).parents[2]
|
||||
|
||||
|
||||
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
|
||||
if file_idx == chunks_size - 1:
|
||||
file_idx = 0
|
||||
chunk_idx += 1
|
||||
else:
|
||||
file_idx += 1
|
||||
return chunk_idx, file_idx
|
||||
|
||||
|
||||
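
# Worked examples of the rotation arithmetic above, with a toy chunks_size of 3:
# files fill slots 0..chunks_size-1 inside a chunk, then the chunk index advances
# and file numbering restarts.
# update_chunk_file_indices(0, 0, 3) -> (0, 1)
# update_chunk_file_indices(0, 2, 3) -> (1, 0)  # chunk rolls over
# update_chunk_file_indices(4, 1, 3) -> (4, 2)
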
def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset:
    """Find parquet files in the provided directory {pq_dir}/chunk-xxx/file-xxx.parquet,
    convert them to pyarrow memory-mapped files in a cache folder for efficient RAM usage,
    and concatenate all pyarrow references into a single HF Dataset.

    Args:
        pq_dir: Directory containing parquet files
        features: Optional features schema to ensure consistent loading of complex types like images
    """
    paths = sorted(pq_dir.glob("*/*.parquet"))
    if len(paths) == 0:
        raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")

    # TODO(rcadene): set num_proc to accelerate conversion to pyarrow
    datasets = [Dataset.from_parquet(str(path), features=features) for path in paths]
    return concatenate_datasets(datasets)


def get_parquet_num_frames(parquet_path: str | Path) -> int:
    metadata = pq.read_metadata(parquet_path)
    return metadata.num_rows


def get_video_size_in_mb(mp4_path: Path) -> float:
    file_size_bytes = mp4_path.stat().st_size
    file_size_mb = file_size_bytes / (1024**2)
    return file_size_mb


def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
    """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
@@ -82,6 +157,7 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
    >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
    >>> print(flatten_dict(dct))
    {"a/b": 1, "a/c/d": 2, "e": 3}
    ```
    """
    items = []
    for k, v in d.items():
@@ -106,23 +182,13 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
    return outdict


-def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
-    split_keys = flattened_key.split(sep)
-    getter = obj[split_keys[0]]
-    if len(split_keys) == 1:
-        return getter
-
-    for key in split_keys[1:]:
-        getter = getter[key]
-
-    return getter


def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
    serialized_dict = {}
    for key, value in flatten_dict(stats).items():
        if isinstance(value, (torch.Tensor, np.ndarray)):
            serialized_dict[key] = value.tolist()
        elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
            serialized_dict[key] = value
        elif isinstance(value, np.generic):
            serialized_dict[key] = value.item()
        elif isinstance(value, (int, float)):
@@ -152,24 +218,7 @@ def write_json(data: dict, fpath: Path) -> None:
        json.dump(data, f, indent=4, ensure_ascii=False)


-def load_jsonlines(fpath: Path) -> list[Any]:
-    with jsonlines.open(fpath, "r") as reader:
-        return list(reader)
-
-
-def write_jsonlines(data: dict, fpath: Path) -> None:
-    fpath.parent.mkdir(exist_ok=True, parents=True)
-    with jsonlines.open(fpath, "w") as writer:
-        writer.write_all(data)
-
-
-def append_jsonlines(data: dict, fpath: Path) -> None:
-    fpath.parent.mkdir(exist_ok=True, parents=True)
-    with jsonlines.open(fpath, "a") as writer:
-        writer.write(data)
-
-
-def write_info(info: dict, local_dir: Path):
+def write_info(info: dict, local_dir: Path) -> None:
    write_json(info, local_dir / INFO_PATH)


@@ -180,65 +229,68 @@ def load_info(local_dir: Path) -> dict:
    return info


-def write_stats(stats: dict, local_dir: Path):
+def write_stats(stats: dict, local_dir: Path) -> None:
    serialized_stats = serialize_dict(stats)
    write_json(serialized_stats, local_dir / STATS_PATH)


-def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
+def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
    stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
    return unflatten_dict(stats)


-def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
+def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None:
    if not (local_dir / STATS_PATH).exists():
        return None
    stats = load_json(local_dir / STATS_PATH)
    return cast_stats_to_numpy(stats)


-def write_task(task_index: int, task: dict, local_dir: Path):
-    task_dict = {
-        "task_index": task_index,
-        "task": task,
-    }
-    append_jsonlines(task_dict, local_dir / TASKS_PATH)
+def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None:
+    path = local_dir / DEFAULT_TASKS_PATH
+    path.parent.mkdir(parents=True, exist_ok=True)
+    tasks.to_parquet(path)


-def load_tasks(local_dir: Path) -> tuple[dict, dict]:
-    tasks = load_jsonlines(local_dir / TASKS_PATH)
-    tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
-    task_to_task_index = {task: task_index for task_index, task in tasks.items()}
-    return tasks, task_to_task_index
+def load_tasks(local_dir: Path) -> pandas.DataFrame:
+    tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
+    return tasks


-def write_episode(episode: dict, local_dir: Path):
-    append_jsonlines(episode, local_dir / EPISODES_PATH)
+def write_episodes(episodes: Dataset, local_dir: Path) -> None:
+    """Write episode metadata to a parquet file in the LeRobot v3.0 format.
+
+    This function writes episode-level metadata to a single parquet file.
+    Used primarily during dataset conversion (v2.1 -> v3.0) and in test fixtures.
+
+    Args:
+        episodes: HuggingFace Dataset containing episode metadata
+        local_dir: Root directory where the dataset will be stored
+    """
+    episode_size_mb = get_hf_dataset_size_in_mb(episodes)
+    if episode_size_mb > DEFAULT_DATA_FILE_SIZE_IN_MB:
+        raise NotImplementedError(
+            f"Episodes dataset is too large ({episode_size_mb} MB) to write to a single file. "
+            f"The current limit is {DEFAULT_DATA_FILE_SIZE_IN_MB} MB. "
+            "This function only supports single-file episode metadata."
+        )
+
+    fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
+    fpath.parent.mkdir(parents=True, exist_ok=True)
+    episodes.to_parquet(fpath)


-def load_episodes(local_dir: Path) -> dict:
-    episodes = load_jsonlines(local_dir / EPISODES_PATH)
-    return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
-
-
-def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
-    # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
-    # is a dictionary of stats and not an integer.
-    episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
-    append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
-
-
-def load_episodes_stats(local_dir: Path) -> dict:
-    episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
-    return {
-        item["episode_index"]: cast_stats_to_numpy(item["stats"])
-        for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
-    }
+def load_episodes(local_dir: Path) -> datasets.Dataset:
+    episodes = load_nested_dataset(local_dir / EPISODES_DIR)
+    # Select episode features/columns containing references to episode data and videos
+    # (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.).
+    # This speeds up access to these data, instead of having to load episode stats as well.
+    episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")])
+    return episodes


def backward_compatible_episodes_stats(
    stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
-) -> dict[str, dict[str, np.ndarray]]:
+) -> dict[int, dict[str, dict[str, np.ndarray]]]:
    return dict.fromkeys(episodes, stats)


@@ -254,7 +306,7 @@ def load_image_as_numpy(
    return img_array


-def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
+def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
    """Get a transform function that convert items from Hugging Face dataset (pyarrow)
    to torch tensors. Importantly, images are converted from PIL, which corresponds to
    a channel last representation (h w c) of uint8 type, to a torch image representation
@@ -299,7 +351,7 @@ def check_version_compatibility(
    if v_check.major < v_current.major and enforce_breaking_major:
        raise BackwardCompatibilityError(repo_id, v_check)
    elif v_check.minor < v_current.minor:
-        logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
+        logging.warning(FUTURE_MESSAGE.format(repo_id=repo_id, version=v_check))


def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
@@ -470,56 +522,15 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
    return policy_features


def combine_feature_dicts(*dicts: dict) -> dict:
    """
    Merge LeRobot grouped feature dicts.

    - For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
    - For others (observation.images.*), the last one wins (if they are identical).
    """
    out: dict = {}
    for d in dicts:
        for key, value in d.items():
            if not isinstance(value, dict):
                out[key] = value
                continue

            dtype = value.get("dtype")
            shape = value.get("shape")
            is_vector = (
                dtype not in ("image", "video", "string")
                and isinstance(shape, tuple)
                and len(shape) == 1
                and "names" in value
            )

            if is_vector:
                # Initialize or retrieve the accumulating dict for this feature key
                target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
                # Ensure consistent data types across merged entries
                if "dtype" in target and dtype != target["dtype"]:
                    raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")

                # Merge feature names: append only new ones to preserve order without duplicates
                seen = set(target["names"])
                for n in value["names"]:
                    if n not in seen:
                        target["names"].append(n)
                        seen.add(n)
                # Recompute the shape to reflect the updated number of features
                target["shape"] = (len(target["names"]),)
            else:
                # For images/videos and non-1D entries: override with the latest definition
                out[key] = value
    return out
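
# A small illustration of the merge rule implemented above; the feature specs are
# hypothetical. 1D numeric specs union their names and recompute the shape, while
# image/video specs are simply overridden by the last definition.
merged = combine_feature_dicts(
    {"action": {"dtype": "float32", "shape": (2,), "names": ["shoulder", "elbow"]}},
    {"action": {"dtype": "float32", "shape": (2,), "names": ["elbow", "gripper"]}},
)
assert merged["action"]["names"] == ["shoulder", "elbow", "gripper"]
assert merged["action"]["shape"] == (3,)
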
def create_empty_dataset_info(
    codebase_version: str,
    fps: int,
    features: dict,
    use_videos: bool,
    robot_type: str | None = None,
    chunks_size: int | None = None,
    data_files_size_in_mb: int | None = None,
    video_files_size_in_mb: int | None = None,
) -> dict:
    return {
        "codebase_version": codebase_version,
@@ -527,104 +538,17 @@ def create_empty_dataset_info(
        "total_episodes": 0,
        "total_frames": 0,
        "total_tasks": 0,
        "total_videos": 0,
        "total_chunks": 0,
        "chunks_size": DEFAULT_CHUNK_SIZE,
        "chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
        "data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
        "video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
        "fps": fps,
        "splits": {},
        "data_path": DEFAULT_PARQUET_PATH,
        "data_path": DEFAULT_DATA_PATH,
        "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
        "features": features,
    }

def get_episode_data_index(
    episode_dicts: dict[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
    episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
    if episodes is not None:
        episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}

    cumulative_lengths = list(accumulate(episode_lengths.values()))
    return {
        "from": torch.LongTensor([0] + cumulative_lengths[:-1]),
        "to": torch.LongTensor(cumulative_lengths),
    }

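# Illustrative usage sketch (editor's example, not part of the diff): two
# episodes of lengths 3 and 2 give cumulative boundaries [3, 5], so:
#
#   get_episode_data_index({0: {"length": 3}, 1: {"length": 2}})
#   # -> {"from": tensor([0, 3]), "to": tensor([3, 5])}
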
def check_timestamps_sync(
    timestamps: np.ndarray,
    episode_indices: np.ndarray,
    episode_data_index: dict[str, np.ndarray],
    fps: int,
    tolerance_s: float,
    raise_value_error: bool = True,
) -> bool:
    """
    This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
    to account for possible numerical error.

    Args:
        timestamps (np.ndarray): Array of timestamps in seconds.
        episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
        episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
            which identifies indices for the end of each episode.
        fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
        tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
        raise_value_error (bool): Whether to raise a ValueError if the check fails.

    Returns:
        bool: True if all checked timestamp differences lie within tolerance, False otherwise.

    Raises:
        ValueError: If the check fails and `raise_value_error` is True.
    """
    if timestamps.shape != episode_indices.shape:
        raise ValueError(
            "timestamps and episode_indices should have the same shape. "
            f"Found {timestamps.shape=} and {episode_indices.shape=}."
        )

    # Consecutive differences
    diffs = np.diff(timestamps)
    within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s

    # Mask to ignore differences at the boundaries between episodes
    mask = np.ones(len(diffs), dtype=bool)
    ignored_diffs = episode_data_index["to"][:-1] - 1  # indices at the end of each episode
    mask[ignored_diffs] = False
    filtered_within_tolerance = within_tolerance[mask]

    # Check if all remaining diffs are within tolerance
    if not np.all(filtered_within_tolerance):
        # Track original indices before masking
        original_indices = np.arange(len(diffs))
        filtered_indices = original_indices[mask]
        outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
        outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]

        outside_tolerances = []
        for idx in outside_tolerance_indices:
            entry = {
                "timestamps": [timestamps[idx], timestamps[idx + 1]],
                "diff": diffs[idx],
                "episode_index": episode_indices[idx].item()
                if hasattr(episode_indices[idx], "item")
                else episode_indices[idx],
            }
            outside_tolerances.append(entry)

        if raise_value_error:
            raise ValueError(
                f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
                This might be due to synchronization issues during data collection.
                \n{pformat(outside_tolerances)}"""
            )
        return False

    return True

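# Illustrative usage sketch (editor's example, not part of the diff): at fps=10
# every in-episode diff must be 0.1 s within tolerance_s, and the single
# boundary diff between episodes (at index "to" - 1) is masked out:
#
#   ts = np.array([0.0, 0.1, 0.2, 0.0, 0.1])
#   eps = np.array([0, 0, 0, 1, 1])
#   edi = {"to": np.array([3, 5])}
#   check_timestamps_sync(ts, eps, edi, fps=10, tolerance_s=1e-4)  # -> True
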
def check_delta_timestamps(
    delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
@@ -663,7 +587,7 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
    return delta_indices


def cycle(iterable):
def cycle(iterable: Any) -> Iterator[Any]:
    """The equivalent of itertools.cycle, but safe for PyTorch dataloaders.

    See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
@@ -676,7 +600,7 @@ def cycle(iterable):
    iterator = iter(iterable)

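# Illustrative usage sketch (editor's example, not part of the diff): cycling a
# dataloader endlessly for step-based training loops:
#
#   dl_iter = cycle(dataloader)
#   for _ in range(num_training_steps):
#       batch = next(dl_iter)
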
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) -> None:
    """Create a branch on an existing Hugging Face repo. Delete the branch if it already
    exists before creating it.
    """
@@ -729,76 +653,28 @@ def create_lerobot_dataset_card(
    )

class IterableNamespace(SimpleNamespace):
    """
    A namespace object that supports both dictionary-like iteration and dot notation access.
    Automatically converts nested dictionaries into IterableNamespaces.

    This class extends SimpleNamespace to provide:
    - Dictionary-style iteration over keys
    - Access to items via both dot notation (obj.key) and brackets (obj["key"])
    - Dictionary-like methods: items(), keys(), values()
    - Recursive conversion of nested dictionaries

    Args:
        dictionary: Optional dictionary to initialize the namespace
        **kwargs: Additional keyword arguments passed to SimpleNamespace

    Examples:
        >>> data = {"name": "Alice", "details": {"age": 25}}
        >>> ns = IterableNamespace(data)
        >>> ns.name
        'Alice'
        >>> ns.details.age
        25
        >>> list(ns.keys())
        ['name', 'details']
        >>> for key, value in ns.items():
        ...     print(f"{key}: {value}")
        name: Alice
        details: IterableNamespace(age=25)
    """

    def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
        super().__init__(**kwargs)
        if dictionary is not None:
            for key, value in dictionary.items():
                if isinstance(value, dict):
                    setattr(self, key, IterableNamespace(value))
                else:
                    setattr(self, key, value)

    def __iter__(self) -> Iterator[str]:
        return iter(vars(self))

    def __getitem__(self, key: str) -> Any:
        return vars(self)[key]

    def items(self):
        return vars(self).items()

    def values(self):
        return vars(self).values()

    def keys(self):
        return vars(self).keys()

def validate_frame(frame: dict, features: dict):
def validate_frame(frame: dict, features: dict) -> None:
    expected_features = set(features) - set(DEFAULT_FEATURES)
    actual_features = set(frame)

    error_message = validate_features_presence(actual_features, expected_features)
    # task is a special required field that's not part of regular features
    if "task" not in actual_features:
        raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n")

    common_features = actual_features & expected_features
    for name in common_features - {"task"}:
    # Remove task from actual_features for regular feature validation
    actual_features_for_validation = actual_features - {"task"}

    error_message = validate_features_presence(actual_features_for_validation, expected_features)

    common_features = actual_features_for_validation & expected_features
    for name in common_features:
        error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])

    if error_message:
        raise ValueError(error_message)


def validate_features_presence(actual_features: set[str], expected_features: set[str]):
def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str:
    error_message = ""
    missing_features = expected_features - actual_features
    extra_features = actual_features - expected_features
@@ -813,7 +689,9 @@ def validate_features_presence(actual_features: set[str], expected_features: set
    return error_message


def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
def validate_feature_dtype_and_shape(
    name: str, feature: dict, value: np.ndarray | PILImage.Image | str
) -> str:
    expected_dtype = feature["dtype"]
    expected_shape = feature["shape"]
    if is_valid_numpy_dtype_string(expected_dtype):
@@ -828,7 +706,7 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray

def validate_feature_numpy_array(
    name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
):
) -> str:
    error_message = ""
    if isinstance(value, np.ndarray):
        actual_dtype = value.dtype
@@ -845,7 +723,9 @@ def validate_feature_numpy_array(
    return error_message


def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
def validate_feature_image_or_video(
    name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
) -> str:
    # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
    error_message = ""
    if isinstance(value, np.ndarray):
@@ -861,13 +741,13 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value:
    return error_message


def validate_feature_string(name: str, value: str):
def validate_feature_string(name: str, value: str) -> str:
    if not isinstance(value, str):
        return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
    return ""

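# Illustrative usage sketch (editor's example, not part of the diff): the
# validate_feature_* helpers return "" on success or an error line, which
# validate_frame accumulates before raising once. A shape mismatch such as
#
#   features = {"observation.state": {"dtype": "float32", "shape": (2,)}}
#   frame = {"observation.state": np.zeros(3, dtype=np.float32), "task": "pick"}
#   validate_frame(frame, features)
#
# raises a single ValueError describing the offending feature.
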
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
    if "size" not in episode_buffer:
        raise ValueError("size key not found in episode_buffer")

@@ -891,3 +771,238 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
            f"In episode_buffer not in features: {buffer_keys - set(features)}"
            f"In features not in episode_buffer: {set(features) - buffer_keys}"
        )

def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
    """This function correctly writes to parquet a pandas DataFrame that contains images encoded by HF dataset.
    This way, it can be loaded by HF dataset and correctly formatted images are returned.
    """
    # TODO(qlhoest): replace this weird syntax by `df.to_parquet(path)` only
    datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)


def item_to_torch(item: dict) -> dict:
    """Convert all items in a dictionary to PyTorch tensors where appropriate.

    This function is used to convert an item from a streaming dataset to PyTorch tensors.

    Args:
        item (dict): Dictionary of items from a dataset.

    Returns:
        dict: Dictionary with all tensor-like items converted to torch.Tensor.
    """
    for key, val in item.items():
        if isinstance(val, (np.ndarray, list)) and key not in ["task"]:
            # Convert numpy arrays and lists to torch tensors
            item[key] = torch.tensor(val)
    return item

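# Illustrative usage sketch (editor's example, not part of the diff): arrays
# and lists become tensors, while the "task" string passes through untouched:
#
#   item_to_torch({"observation.state": np.zeros(2), "task": "pick cube"})
#   # -> {"observation.state": tensor([0., 0.]), "task": "pick cube"}
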
def is_float_in_list(target, float_list, threshold=1e-6):
    return any(abs(target - x) <= threshold for x in float_list)


def find_float_index(target, float_list, threshold=1e-6):
    for i, x in enumerate(float_list):
        if abs(target - x) <= threshold:
            return i
    return -1

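# Illustrative usage sketch (editor's example, not part of the diff):
# approximate membership avoids exact float equality on recomputed timestamps:
#
#   is_float_in_list(0.1, [0.0, 0.1000000001, 0.2])  # -> True
#   find_float_index(0.2, [0.0, 0.1, 0.2])           # -> 2
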
class LookBackError(Exception):
    """
    Exception raised when trying to look back in the history of a Backtrackable object.
    """

    pass


class LookAheadError(Exception):
    """
    Exception raised when trying to look ahead in the future of a Backtrackable object.
    """

    pass


class Backtrackable(Generic[T]):
    """
    Wrap any iterator/iterable so you can step back up to `history` items
    and look ahead up to `lookahead` items.

    This is useful for streaming datasets where you need to access previous and future items
    but can't load the entire dataset into memory.

    Example:
    -------
    ```python
    ds = load_dataset("c4", "en", streaming=True, split="train")
    rev = Backtrackable(ds, history=3, lookahead=2)

    x0 = next(rev)  # forward
    x1 = next(rev)
    x2 = next(rev)

    # Look ahead
    x3_peek = rev.peek_ahead(1)  # next item without moving cursor
    x4_peek = rev.peek_ahead(2)  # two items ahead

    # Look back
    x1_again = rev.peek_back(1)  # previous item without moving cursor
    x0_again = rev.peek_back(2)  # two items back

    # Move backward
    x1_back = rev.prev()  # back one step
    next(rev)  # returns x2, continues forward from where we were
    ```
    """

    __slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead")

    def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0):
        if history < 1:
            raise ValueError("history must be >= 1")
        if lookahead < 0:
            raise ValueError("lookahead must be >= 0")

        self._source: Iterator[T] = iter(iterable)
        self._back_buf: Deque[T] = deque(maxlen=history)
        self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
        self._cursor: int = 0
        self._history = history
        self._lookahead = lookahead

    def __iter__(self) -> "Backtrackable[T]":
        return self

    def __next__(self) -> T:
        # If we've stepped back, consume from back buffer first.
        # The current item sits at back_buf[cursor - 1] (cursor=0 means the frontier).
        if self._cursor < 0:  # -1 means "one step behind the frontier", etc.
            self._cursor += 1
            return self._back_buf[self._cursor - 1]

        # If we have items in the ahead buffer, use them first
        item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source)

        # Add current item to back buffer and reset cursor
        self._back_buf.append(item)
        self._cursor = 0
        return item

    def prev(self) -> T:
        """
        Step one item back in history and return it.
        Raises LookBackError if already at the oldest buffered item.
        """
        if len(self._back_buf) + self._cursor <= 1:
            raise LookBackError("At start of history")

        self._cursor -= 1
        return self._back_buf[self._cursor - 1]

    def peek_back(self, n: int = 1) -> T:
        """
        Look `n` items back (n=1 == previous item) without moving the cursor.
        """
        if n < 0 or n + 1 > len(self._back_buf) + self._cursor:
            raise LookBackError("peek_back distance out of range")

        return self._back_buf[self._cursor - (n + 1)]

    def peek_ahead(self, n: int = 1) -> T:
        """
        Look `n` items ahead (n=1 == next item) without moving the cursor.
        Fills the ahead buffer if necessary.
        """
        if n < 1:
            raise LookAheadError("peek_ahead distance must be 1 or more")
        elif n > self._lookahead:
            raise LookAheadError("peek_ahead distance exceeds lookahead limit")

        # Fill ahead buffer if we don't have enough items
        while len(self._ahead_buf) < n:
            try:
                item = next(self._source)
                self._ahead_buf.append(item)

            except StopIteration as err:
                raise LookAheadError("peek_ahead: not enough items in source") from err

        return self._ahead_buf[n - 1]

    def history(self) -> list[T]:
        """
        Return a copy of the buffered history (most recent last).
        The list length ≤ `history` argument passed at construction.
        """
        if self._cursor == 0:
            return list(self._back_buf)

        # When cursor<0, slice so the order remains chronological
        return list(self._back_buf)[: self._cursor or None]

    def lookahead_buffer(self) -> list[T]:
        """
        Return a copy of the current lookahead buffer.
        """
        return list(self._ahead_buf)

    def can_peek_back(self, steps: int = 1) -> bool:
        """
        Check if we can go back `steps` items without raising a LookBackError.
        """
        return steps + 1 <= len(self._back_buf) + self._cursor

    def can_peek_ahead(self, steps: int = 1) -> bool:
        """
        Check if we can peek ahead `steps` items.
        This may involve trying to fill the ahead buffer.
        """
        if self._lookahead > 0 and steps > self._lookahead:
            return False

        # Try to fill ahead buffer to check if we can peek that far
        try:
            while len(self._ahead_buf) < steps:
                if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead:
                    return False
                item = next(self._source)
                self._ahead_buf.append(item)
            return True
        except StopIteration:
            return False

    def reset_cursor(self) -> None:
        """
        Reset cursor to the most recent position (equivalent to calling next()
        until you're back to the latest item).
        """
        self._cursor = 0

    def clear_ahead_buffer(self) -> None:
        """
        Clear the ahead buffer, discarding any pre-fetched items.
        """
        self._ahead_buf.clear()

    def switch_source_iterable(self, new_source: Iterable[T]) -> None:
        """
        Switch the source of the backtrackable to a new iterable, keeping the history.

        This is useful when iterating over a sequence of datasets. The history from the
        previous source is kept, but the lookahead buffer is cleared. The cursor is reset
        to the present.
        """
        self._source = iter(new_source)
        self.clear_ahead_buffer()
        self.reset_cursor()

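# Illustrative usage sketch (editor's example, not part of the diff): chaining
# a second dataset while keeping the history from the first:
#
#   bt = Backtrackable(dataset_a, history=2, lookahead=1)
#   last_a = next(bt)
#   bt.switch_source_iterable(dataset_b)  # lookahead cleared, history kept
#   first_b = next(bt)
#   assert bt.peek_back(1) is last_a
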
def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.IterableDataset:
    """
    Safely shards the dataset, clamping the shard index to the number of shards actually available.
    """
    shard_idx = min(dataset.num_shards, index + 1) - 1

    return dataset.shard(num_shards, index=shard_idx)

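# Illustrative usage sketch (editor's example, not part of the diff): with a
# 2-shard streaming dataset and 4 dataloader workers, worker index 3 is
# clamped to the last available shard instead of raising:
#
#   shard = safe_shard(streaming_ds, index=3, num_shards=4)
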
@@ -1,884 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.

Note: Since the original Aloha datasets don't use shadow motors, you need to comment those out in
lerobot/configs/robot/aloha.yaml before running this script.
"""

import traceback
from pathlib import Path
from textwrap import dedent

from lerobot import available_datasets
from lerobot.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
from lerobot.robots.aloha.configuration_aloha import AlohaRobotConfig

LOCAL_DIR = Path("data/")

# spellchecker:off
ALOHA_MOBILE_INFO = {
    "robot_config": AlohaRobotConfig(),
    "license": "mit",
    "url": "https://mobile-aloha.github.io/",
    "paper": "https://huggingface.co/papers/2401.02117",
    "citation_bibtex": dedent(r"""
        @inproceedings{fu2024mobile,
            author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea},
            title = {Mobile ALOHA: Learning Bimanual Mobile Manipulation with Low-Cost Whole-Body Teleoperation},
            booktitle = {arXiv},
            year = {2024},
        }""").lstrip(),
}
ALOHA_STATIC_INFO = {
    "robot_config": AlohaRobotConfig(),
    "license": "mit",
    "url": "https://tonyzhaozh.github.io/aloha/",
    "paper": "https://huggingface.co/papers/2304.13705",
    "citation_bibtex": dedent(r"""
        @article{Zhao2023LearningFB,
            title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
            author={Tony Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
            journal={RSS},
            year={2023},
            volume={abs/2304.13705},
            url={https://huggingface.co/papers/2304.13705}
        }""").lstrip(),
}
PUSHT_INFO = {
    "license": "mit",
    "url": "https://diffusion-policy.cs.columbia.edu/",
    "paper": "https://huggingface.co/papers/2303.04137",
    "citation_bibtex": dedent(r"""
        @article{chi2024diffusionpolicy,
            author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
            title = {Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
            journal = {The International Journal of Robotics Research},
            year = {2024},
        }""").lstrip(),
}
XARM_INFO = {
    "license": "mit",
    "url": "https://www.nicklashansen.com/td-mpc/",
    "paper": "https://huggingface.co/papers/2203.04955",
    "citation_bibtex": dedent(r"""
        @inproceedings{Hansen2022tdmpc,
            title={Temporal Difference Learning for Model Predictive Control},
            author={Nicklas Hansen and Xiaolong Wang and Hao Su},
            booktitle={ICML},
            year={2022}
        }
    """),
}
UNITREEH_INFO = {
    "license": "apache-2.0",
}

DATASETS = {
    "aloha_mobile_cabinet": {
        "single_task": "Open the top cabinet, store the pot inside it then close the cabinet.",
        **ALOHA_MOBILE_INFO,
    },
    "aloha_mobile_chair": {
        "single_task": "Push the chairs in front of the desk to place them against it.",
        **ALOHA_MOBILE_INFO,
    },
    "aloha_mobile_elevator": {
        "single_task": "Take the elevator to the 1st floor.",
        **ALOHA_MOBILE_INFO,
    },
    "aloha_mobile_shrimp": {
        "single_task": "Sauté the raw shrimp on both sides, then serve it in the bowl.",
        **ALOHA_MOBILE_INFO,
    },
    "aloha_mobile_wash_pan": {
        "single_task": "Pick up the pan, rinse it in the sink and then place it in the drying rack.",
        **ALOHA_MOBILE_INFO,
    },
    "aloha_mobile_wipe_wine": {
        "single_task": "Pick up the wet cloth on the faucet and use it to clean the spilled wine on the table and underneath the glass.",
        **ALOHA_MOBILE_INFO,
    },
    "aloha_static_battery": {
        "single_task": "Place the battery into the slot of the remote controller.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
    "aloha_static_coffee": {
        "single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_static_coffee_new": {
        "single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_static_cups_open": {
        "single_task": "Pick up the plastic cup and open its lid.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_static_fork_pick_up": {
        "single_task": "Pick up the fork and place it on the plate.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_static_pingpong_test": {
        "single_task": "Transfer one of the two balls in the right glass into the left glass, then transfer it back to the right glass.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_static_pro_pencil": {
        "single_task": "Pick up the pencil with the right arm, hand it over to the left arm then place it back onto the table.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_static_screw_driver": {
        "single_task": "Pick up the screwdriver with the right arm, hand it over to the left arm then place it into the cup.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_static_tape": {
        "single_task": "Cut a small piece of tape from the tape dispenser then place it on the cardboard box's edge.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_static_thread_velcro": {
        "single_task": "Pick up the velcro cable tie with the left arm, then insert the end of the velcro tie into the other end's loop with the right arm.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_static_towel": {
        "single_task": "Pick up a piece of paper towel and place it on the spilled liquid.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_static_vinh_cup": {
        "single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_static_vinh_cup_left": {
        "single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
    "aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
    "aloha_sim_insertion_scripted_image": {
        "single_task": "Insert the peg into the socket.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
    "aloha_sim_insertion_human_image": {
        "single_task": "Insert the peg into the socket.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_sim_transfer_cube_scripted": {
        "single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_sim_transfer_cube_scripted_image": {
        "single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_sim_transfer_cube_human": {
        "single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
        **ALOHA_STATIC_INFO,
    },
    "aloha_sim_transfer_cube_human_image": {
        "single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
        **ALOHA_STATIC_INFO,
    },
    "pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
    "pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
    "unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
    "unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
    "unitreeh1_two_robot_greeting": {
        "single_task": "Greet the other robot with a high five.",
        **UNITREEH_INFO,
    },
    "unitreeh1_warehouse": {
        "single_task": "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.",
        **UNITREEH_INFO,
    },
    "xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
    "xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
    "xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
    "xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
    "xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
    "xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
    "xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
    "xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"umi_cup_in_the_wild": {
|
||||
"single_task": "Put the cup on the plate.",
|
||||
"license": "apache-2.0",
|
||||
},
|
||||
"asu_table_top": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://link.springer.com/article/10.1007/s10514-023-10129-1",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{zhou2023modularity,
|
||||
title={Modularity through Attention: Efficient Training and Transfer of Language-Conditioned Policies for Robot Manipulation},
|
||||
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Stepputtis, Simon and Amor, Heni},
|
||||
booktitle={Conference on Robot Learning},
|
||||
pages={1684--1695},
|
||||
year={2023},
|
||||
organization={PMLR}
|
||||
}
|
||||
@article{zhou2023learning,
|
||||
title={Learning modular language-conditioned robot policies through attention},
|
||||
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Ben Amor, Heni and Stepputtis, Simon},
|
||||
journal={Autonomous Robots},
|
||||
pages={1--21},
|
||||
year={2023},
|
||||
publisher={Springer}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"austin_buds_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/BUDS-website/",
|
||||
"paper": "https://huggingface.co/papers/2109.13841",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{zhu2022bottom,
|
||||
title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation},
|
||||
author={Zhu, Yifeng and Stone, Peter and Zhu, Yuke},
|
||||
journal={IEEE Robotics and Automation Letters},
|
||||
volume={7},
|
||||
number={2},
|
||||
pages={4126--4133},
|
||||
year={2022},
|
||||
publisher={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"austin_sailor_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/sailor/",
|
||||
"paper": "https://huggingface.co/papers/2210.11435",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{nasiriany2022sailor,
|
||||
title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning},
|
||||
author={Soroush Nasiriany and Tian Gao and Ajay Mandlekar and Yuke Zhu},
|
||||
booktitle={Conference on Robot Learning (CoRL)},
|
||||
year={2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"austin_sirius_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/sirius/",
|
||||
"paper": "https://huggingface.co/papers/2211.08416",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{liu2022robot,
|
||||
title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment},
|
||||
author = {Huihan Liu and Soroush Nasiriany and Lance Zhang and Zhiyao Bao and Yuke Zhu},
|
||||
booktitle = {Robotics: Science and Systems (RSS)},
|
||||
year = {2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_autolab_ur5": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://sites.google.com/view/berkeley-ur5/home",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{BerkeleyUR5Website,
|
||||
title = {Berkeley {UR5} Demonstration Dataset},
|
||||
author = {Lawrence Yunliang Chen and Simeon Adebola and Ken Goldberg},
|
||||
howpublished = {https://sites.google.com/view/berkeley-ur5/home},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_cable_routing": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://sites.google.com/view/cablerouting/home",
|
||||
"paper": "https://huggingface.co/papers/2307.08927",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{luo2023multistage,
|
||||
author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine},
|
||||
title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning},
|
||||
journal = {arXiv pre-print},
|
||||
year = {2023},
|
||||
url = {https://huggingface.co/papers/2307.08927},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_fanuc_manipulation": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/berkeley.edu/fanuc-manipulation",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{fanuc_manipulation2023,
|
||||
title={Fanuc Manipulation: A Dataset for Learning-based Manipulation with FANUC Mate 200iD Robot},
|
||||
author={Zhu, Xinghao and Tian, Ran and Xu, Chenfeng and Ding, Mingyu and Zhan, Wei and Tomizuka, Masayoshi},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_gnm_cory_hall": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://huggingface.co/papers/1709.10489",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{kahn2018self,
|
||||
title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation},
|
||||
author={Kahn, Gregory and Villaflor, Adam and Ding, Bosen and Abbeel, Pieter and Levine, Sergey},
|
||||
booktitle={2018 IEEE international conference on robotics and automation (ICRA)},
|
||||
pages={5129--5136},
|
||||
year={2018},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_gnm_recon": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/recon-robot",
|
||||
"paper": "https://huggingface.co/papers/2104.05859",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{shah2021rapid,
|
||||
title={Rapid Exploration for Open-World Navigation with Latent Goal Models},
|
||||
author={Dhruv Shah and Benjamin Eysenbach and Nicholas Rhinehart and Sergey Levine},
|
||||
booktitle={5th Annual Conference on Robot Learning },
|
||||
year={2021},
|
||||
url={https://openreview.net/forum?id=d_SWJhyKfVw}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_gnm_sac_son": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/SACSoN-review",
|
||||
"paper": "https://huggingface.co/papers/2306.01874",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{hirose2023sacson,
|
||||
title={SACSoN: Scalable Autonomous Data Collection for Social Navigation},
|
||||
author={Hirose, Noriaki and Shah, Dhruv and Sridhar, Ajay and Levine, Sergey},
|
||||
journal={arXiv preprint arXiv:2306.01874},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_mvp": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://huggingface.co/papers/2203.06173",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@InProceedings{Radosavovic2022,
|
||||
title = {Real-World Robot Learning with Masked Visual Pre-training},
|
||||
author = {Ilija Radosavovic and Tete Xiao and Stephen James and Pieter Abbeel and Jitendra Malik and Trevor Darrell},
|
||||
booktitle = {CoRL},
|
||||
year = {2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_rpt": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://huggingface.co/papers/2306.10007",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{Radosavovic2023,
|
||||
title={Robot Learning with Sensorimotor Pre-training},
|
||||
author={Ilija Radosavovic and Baifeng Shi and Letian Fu and Ken Goldberg and Trevor Darrell and Jitendra Malik},
|
||||
year={2023},
|
||||
journal={arXiv:2306.10007}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"cmu_franka_exploration_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://human-world-model.github.io/",
|
||||
"paper": "https://huggingface.co/papers/2308.10901",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{mendonca2023structured,
|
||||
title={Structured World Models from Human Videos},
|
||||
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
|
||||
journal={RSS},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"cmu_play_fusion": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://play-fusion.github.io/",
|
||||
"paper": "https://huggingface.co/papers/2312.04549",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{chen2023playfusion,
|
||||
title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play},
|
||||
author={Chen, Lili and Bahl, Shikhar and Pathak, Deepak},
|
||||
booktitle={CoRL},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"cmu_stretch": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://robo-affordances.github.io/",
|
||||
"paper": "https://huggingface.co/papers/2304.08488",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{bahl2023affordances,
|
||||
title={Affordances from Human Videos as a Versatile Representation for Robotics},
|
||||
author={Bahl, Shikhar and Mendonca, Russell and Chen, Lili and Jain, Unnat and Pathak, Deepak},
|
||||
booktitle={CVPR},
|
||||
year={2023}
|
||||
}
|
||||
@article{mendonca2023structured,
|
||||
title={Structured World Models from Human Videos},
|
||||
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
|
||||
journal={CoRL},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"columbia_cairlab_pusht_real": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://diffusion-policy.cs.columbia.edu/",
|
||||
"paper": "https://huggingface.co/papers/2303.04137",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{chi2023diffusionpolicy,
|
||||
title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
|
||||
author={Chi, Cheng and Feng, Siyuan and Du, Yilun and Xu, Zhenjia and Cousineau, Eric and Burchfiel, Benjamin and Song, Shuran},
|
||||
booktitle={Proceedings of Robotics: Science and Systems (RSS)},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"conq_hose_manipulation": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/conq-hose-manipulation-dataset/home",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{ConqHoseManipData,
|
||||
author={Peter Mitrano and Dmitry Berenson},
|
||||
title={Conq Hose Manipulation Dataset, v1.15.0},
|
||||
year={2024},
|
||||
howpublished={https://sites.google.com/view/conq-hose-manipulation-dataset}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"dlr_edan_shared_control": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://ieeexplore.ieee.org/document/9341156",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{vogel_edan_2020,
|
||||
title = {EDAN - an EMG-Controlled Daily Assistant to Help People with Physical Disabilities},
|
||||
language = {en},
|
||||
booktitle = {2020 {IEEE}/{RSJ} {International} {Conference} on {Intelligent} {Robots} and {Systems} ({IROS})},
|
||||
author = {Vogel, Jörn and Hagengruber, Annette and Iskandar, Maged and Quere, Gabriel and Leipscher, Ulrike and Bustamante, Samuel and Dietrich, Alexander and Hoeppner, Hannes and Leidner, Daniel and Albu-Schäffer, Alin},
|
||||
year = {2020}
|
||||
}
|
||||
@inproceedings{quere_shared_2020,
|
||||
address = {Paris, France},
|
||||
title = {Shared {Control} {Templates} for {Assistive} {Robotics}},
|
||||
language = {en},
|
||||
booktitle = {2020 {IEEE} {International} {Conference} on {Robotics} and {Automation} ({ICRA})},
|
||||
author = {Quere, Gabriel and Hagengruber, Annette and Iskandar, Maged and Bustamante, Samuel and Leidner, Daniel and Stulp, Freek and Vogel, Joern},
|
||||
year = {2020},
|
||||
pages = {7},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"dlr_sara_grid_clamp": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://www.researchsquare.com/article/rs-3289569/v1",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{padalkar2023guided,
|
||||
title={A guided reinforcement learning approach using shared control templates for learning manipulation skills in the real world},
|
||||
author={Padalkar, Abhishek and Quere, Gabriel and Raffin, Antonin and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
|
||||
journal={Research square preprint rs-3289569/v1},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"dlr_sara_pour": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://elib.dlr.de/193739/1/padalkar2023rlsct.pdf",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{padalkar2023guiding,
|
||||
title={Guiding Reinforcement Learning with Shared Control Templates},
|
||||
author={Padalkar, Abhishek and Quere, Gabriel and Steinmetz, Franz and Raffin, Antonin and Nieuwenhuisen, Matthias and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
|
||||
booktitle={40th IEEE International Conference on Robotics and Automation, ICRA 2023},
|
||||
year={2023},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"droid_100": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://droid-dataset.github.io/",
|
||||
"paper": "https://huggingface.co/papers/2403.12945",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{khazatsky2024droid,
|
||||
title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset},
|
||||
author = {Alexander Khazatsky and Karl Pertsch and Suraj Nair and Ashwin Balakrishna and Sudeep Dasari and Siddharth Karamcheti and Soroush Nasiriany and Mohan Kumar Srirama and Lawrence Yunliang Chen and Kirsty Ellis and Peter David Fagan and Joey Hejna and Masha Itkina and Marion Lepert and Yecheng Jason Ma and Patrick Tree Miller and Jimmy Wu and Suneel Belkhale and Shivin Dass and Huy Ha and Arhan Jain and Abraham Lee and Youngwoon Lee and Marius Memmel and Sungjae Park and Ilija Radosavovic and Kaiyuan Wang and Albert Zhan and Kevin Black and Cheng Chi and Kyle Beltran Hatch and Shan Lin and Jingpei Lu and Jean Mercat and Abdul Rehman and Pannag R Sanketi and Archit Sharma and Cody Simpson and Quan Vuong and Homer Rich Walke and Blake Wulfe and Ted Xiao and Jonathan Heewon Yang and Arefeh Yavary and Tony Z. Zhao and Christopher Agia and Rohan Baijal and Mateo Guaman Castro and Daphne Chen and Qiuyu Chen and Trinity Chung and Jaimyn Drake and Ethan Paul Foster and Jensen Gao and David Antonio Herrera and Minho Heo and Kyle Hsu and Jiaheng Hu and Donovon Jackson and Charlotte Le and Yunshuang Li and Kevin Lin and Roy Lin and Zehan Ma and Abhiram Maddukuri and Suvir Mirchandani and Daniel Morton and Tony Nguyen and Abigail O'Neill and Rosario Scalise and Derick Seale and Victor Son and Stephen Tian and Emi Tran and Andrew E. Wang and Yilin Wu and Annie Xie and Jingyun Yang and Patrick Yin and Yunchu Zhang and Osbert Bastani and Glen Berseth and Jeannette Bohg and Ken Goldberg and Abhinav Gupta and Abhishek Gupta and Dinesh Jayaraman and Joseph J Lim and Jitendra Malik and Roberto Martín-Martín and Subramanian Ramamoorthy and Dorsa Sadigh and Shuran Song and Jiajun Wu and Michael C. Yip and Yuke Zhu and Thomas Kollar and Sergey Levine and Chelsea Finn},
|
||||
year = {2024},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"fmb": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://functional-manipulation-benchmark.github.io/",
|
||||
"paper": "https://huggingface.co/papers/2401.08553",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{luo2024fmb,
|
||||
title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning},
|
||||
author={Luo, Jianlan and Xu, Charles and Liu, Fangchen and Tan, Liam and Lin, Zipeng and Wu, Jeffrey and Abbeel, Pieter and Levine, Sergey},
|
||||
journal={arXiv preprint arXiv:2401.08553},
|
||||
year={2024}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"iamlab_cmu_pickup_insert": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://openreview.net/forum?id=WuBv9-IGDUA",
|
||||
"paper": "https://huggingface.co/papers/2401.14502",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{saxena2023multiresolution,
|
||||
title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models},
|
||||
author={Saumya Saxena and Mohit Sharma and Oliver Kroemer},
|
||||
booktitle={7th Annual Conference on Robot Learning},
|
||||
year={2023},
|
||||
url={https://openreview.net/forum?id=WuBv9-IGDUA}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"imperialcollege_sawyer_wrist_cam": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
},
|
||||
"jaco_play": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://github.com/clvrai/clvr_jaco_play_dataset",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@software{dass2023jacoplay,
|
||||
author = {Dass, Shivin and Yapeter, Jullian and Zhang, Jesse and Zhang, Jiahui
|
||||
and Pertsch, Karl and Nikolaidis, Stefanos and Lim, Joseph J.},
|
||||
title = {CLVR Jaco Play Dataset},
|
||||
url = {https://github.com/clvrai/clvr_jaco_play_dataset},
|
||||
version = {1.0.0},
|
||||
year = {2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"kaist_nonprehensile": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://github.com/JaeHyung-Kim/rlds_dataset_builder",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{kimpre,
|
||||
title={Pre-and post-contact policy decomposition for non-prehensile manipulation with zero-shot sim-to-real transfer},
|
||||
author={Kim, Minchan and Han, Junhyek and Kim, Jaehyung and Kim, Beomjoon},
|
||||
booktitle={2023 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
|
||||
year={2023},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"nyu_door_opening_surprising_effectiveness": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://jyopari.github.io/VINN/",
|
||||
"paper": "https://huggingface.co/papers/2112.01511",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{pari2021surprising,
|
||||
title={The Surprising Effectiveness of Representation Learning for Visual Imitation},
|
||||
author={Jyothish Pari and Nur Muhammad Shafiullah and Sridhar Pandian Arunachalam and Lerrel Pinto},
|
||||
year={2021},
|
||||
eprint={2112.01511},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"nyu_franka_play_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://play-to-policy.github.io/",
|
||||
"paper": "https://huggingface.co/papers/2210.10047",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{cui2022play,
|
||||
title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data},
|
||||
author = {Cui, Zichen Jeff and Wang, Yibin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
|
||||
journal = {arXiv preprint arXiv:2210.10047},
|
||||
year = {2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"nyu_rot_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://rot-robot.github.io/",
|
||||
"paper": "https://huggingface.co/papers/2206.15469",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{haldar2023watch,
|
||||
title={Watch and match: Supercharging imitation with regularized optimal transport},
|
||||
author={Haldar, Siddhant and Mathur, Vaibhav and Yarats, Denis and Pinto, Lerrel},
|
||||
booktitle={Conference on Robot Learning},
|
||||
pages={32--43},
|
||||
year={2023},
|
||||
organization={PMLR}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"roboturk": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://roboturk.stanford.edu/dataset_real.html",
|
||||
"paper": "PAPER",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{mandlekar2019scaling,
|
||||
title={Scaling robot supervision to hundreds of hours with roboturk: Robotic manipulation dataset through human reasoning and dexterity},
|
||||
author={Mandlekar, Ajay and Booher, Jonathan and Spero, Max and Tung, Albert and Gupta, Anchit and Zhu, Yuke and Garg, Animesh and Savarese, Silvio and Fei-Fei, Li},
|
||||
booktitle={2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
|
||||
pages={1048--1055},
|
||||
year={2019},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"stanford_hydra_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/hydra-il-2023",
|
||||
"paper": "https://huggingface.co/papers/2306.17237",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{belkhale2023hydra,
|
||||
title={HYDRA: Hybrid Robot Actions for Imitation Learning},
|
||||
author={Belkhale, Suneel and Cui, Yuchen and Sadigh, Dorsa},
|
||||
journal={arxiv},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"stanford_kuka_multimodal_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/visionandtouch",
|
||||
"paper": "https://huggingface.co/papers/1810.10191",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{lee2019icra,
|
||||
title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks},
|
||||
author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette},
|
||||
booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)},
|
||||
year={2019},
|
||||
url={https://huggingface.co/papers/1810.10191}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"stanford_robocook": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://hshi74.github.io/robocook/",
|
||||
"paper": "https://huggingface.co/papers/2306.14447",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{shi2023robocook,
|
||||
title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools},
|
||||
author={Shi, Haochen and Xu, Huazhe and Clarke, Samuel and Li, Yunzhu and Wu, Jiajun},
|
||||
journal={arXiv preprint arXiv:2306.14447},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"taco_play": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://www.kaggle.com/datasets/oiermees/taco-robot",
|
||||
"paper": "https://huggingface.co/papers/2209.08959, https://huggingface.co/papers/2210.01911",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{rosete2022tacorl,
|
||||
author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and Joschka Boedecker and Wolfram Burgard},
|
||||
title = {Latent Plans for Task Agnostic Offline Reinforcement Learning},
|
||||
journal = {Proceedings of the 6th Conference on Robot Learning (CoRL)},
|
||||
year = {2022}
|
||||
}
|
||||
@inproceedings{mees23hulc2,
|
||||
title={Grounding Language with Visual Affordances over Unstructured Data},
|
||||
author={Oier Mees and Jessica Borja-Diaz and Wolfram Burgard},
|
||||
booktitle = {Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)},
|
||||
year={2023},
|
||||
address = {London, UK}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"tokyo_u_lsmo": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "URL",
|
||||
"paper": "https://huggingface.co/papers/2107.05842",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@Article{Osa22,
|
||||
author = {Takayuki Osa},
|
||||
journal = {The International Journal of Robotics Research},
|
||||
title = {Motion Planning by Learning the Solution Manifold in Trajectory Optimization},
|
||||
year = {2022},
|
||||
number = {3},
|
||||
pages = {291--311},
|
||||
volume = {41},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"toto": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://toto-benchmark.org/",
|
||||
"paper": "https://huggingface.co/papers/2306.00942",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{zhou2023train,
|
||||
author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav},
|
||||
booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)},
|
||||
title={Train Offline, Test Online: A Real Robot Learning Benchmark},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"ucsd_kitchen_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@ARTICLE{ucsd_kitchens,
|
||||
author = {Ge Yan, Kris Wu, and Xiaolong Wang},
|
||||
title = {{ucsd kitchens Dataset}},
|
||||
year = {2023},
|
||||
month = {August}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"ucsd_pick_and_place_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://owmcorl.github.io/#",
|
||||
"paper": "https://huggingface.co/papers/2310.16029",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@preprint{Feng2023Finetuning,
|
||||
title={Finetuning Offline World Models in the Real World},
|
||||
author={Yunhai Feng, Nicklas Hansen, Ziyan Xiong, Chandramouli Rajagopalan, Xiaolong Wang},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"uiuc_d3field": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://robopil.github.io/d3fields/",
|
||||
"paper": "https://huggingface.co/papers/2309.16118",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{wang2023d3field,
|
||||
title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation},
|
||||
author={Wang, Yixuan and Li, Zhuoran and Zhang, Mingtong and Driggs-Campbell, Katherine and Wu, Jiajun and Fei-Fei, Li and Li, Yunzhu},
|
||||
journal={arXiv preprint arXiv:},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"usc_cloth_sim": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://uscresl.github.io/dmfd/",
|
||||
"paper": "https://huggingface.co/papers/2207.10148",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{salhotra2022dmfd,
|
||||
author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.},
|
||||
journal={IEEE Robotics and Automation Letters},
|
||||
title={Learning Deformable Object Manipulation From Expert Demonstrations},
|
||||
year={2022},
|
||||
volume={7},
|
||||
number={4},
|
||||
pages={8775-8782},
|
||||
doi={10.1109/LRA.2022.3187843}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utaustin_mutex": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/MUTEX/",
|
||||
"paper": "https://huggingface.co/papers/2309.14320",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{shah2023mutex,
|
||||
title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications},
|
||||
author={Rutav Shah and Roberto Mart{\'\i}n-Mart{\'\i}n and Yuke Zhu},
|
||||
booktitle={7th Annual Conference on Robot Learning},
|
||||
year={2023},
|
||||
url={https://openreview.net/forum?id=PwqiqaaEzJ}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_pr2_opening_fridge": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{oh2023pr2utokyodatasets,
|
||||
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
|
||||
title={X-Embodiment U-Tokyo PR2 Datasets},
|
||||
year={2023},
|
||||
url={https://github.com/ojh6404/rlds_dataset_builder},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_pr2_tabletop_manipulation": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{oh2023pr2utokyodatasets,
|
||||
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
|
||||
title={X-Embodiment U-Tokyo PR2 Datasets},
|
||||
year={2023},
|
||||
url={https://github.com/ojh6404/rlds_dataset_builder},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_saytap": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://saytap.github.io/",
|
||||
"paper": "https://huggingface.co/papers/2306.07580",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{saytap2023,
|
||||
author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and
|
||||
Tatsuya Harada},
|
||||
title = {SayTap: Language to Quadrupedal Locomotion},
|
||||
eprint = {arXiv:2306.07580},
|
||||
url = {https://saytap.github.io},
|
||||
note = {https://saytap.github.io},
|
||||
year = {2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_xarm_bimanual": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{matsushima2023weblab,
|
||||
title={Weblab xArm Dataset},
|
||||
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_xarm_pick_and_place": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{matsushima2023weblab,
|
||||
title={Weblab xArm Dataset},
|
||||
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"viola": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/VIOLA/",
|
||||
"paper": "https://huggingface.co/papers/2210.11339",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{zhu2022viola,
|
||||
title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors},
|
||||
author={Zhu, Yifeng and Joshi, Abhishek and Stone, Peter and Zhu, Yuke},
|
||||
journal={6th Annual Conference on Robot Learning (CoRL)},
|
||||
year={2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
}
|
||||
# spellchecker:on
|
||||
|
||||
|
||||
def batch_convert():
|
||||
status = {}
|
||||
logfile = LOCAL_DIR / "conversion_log.txt"
|
||||
assert set(DATASETS) == {id_.split("/")[1] for id_ in available_datasets}
|
||||
for num, (name, kwargs) in enumerate(DATASETS.items()):
|
||||
repo_id = f"lerobot/{name}"
|
||||
print(f"\nConverting {repo_id} ({num}/{len(DATASETS)})")
|
||||
print("---------------------------------------------------------")
|
||||
try:
|
||||
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
|
||||
status = f"{repo_id}: success."
|
||||
with open(logfile, "a") as file:
|
||||
file.write(status + "\n")
|
||||
except Exception:
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
with open(logfile, "a") as file:
|
||||
file.write(status + "\n")
|
||||
continue
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_convert()
|
||||
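Each `DATASETS` entry above maps a Hub repo name to the keyword arguments that `batch_convert()` forwards to `convert_dataset()`. As a rough, hypothetical sketch of a single iteration done by hand (the repo id and kwargs mirror the "viola" entry above; the card fields such as license/url/paper flow through as `**card_kwargs`):

```python
from pathlib import Path

# Illustrative one-off conversion mirroring one iteration of batch_convert();
# "lerobot/viola" and the kwargs are copied from the DATASETS entry above.
convert_dataset(
    "lerobot/viola",
    Path("data/"),
    tasks_col="language_instruction",
    license="mit",
    url="https://ut-austin-rpl.github.io/VIOLA/",
    paper="https://huggingface.co/papers/2210.11339",
)
```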
@@ -1,687 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
for each of the tasks performed in the dataset. This will allow you to easily train models with
task-conditioning.

We support 3 different scenarios for these tasks (see instructions below):
1. Single task dataset: all episodes of your dataset have the same single task.
2. Single task episodes: the episodes of your dataset each contain a single task, but it can differ from
   one episode to the next.
3. Multi task episodes: episodes of your dataset may each contain several different tasks.


You can also provide a robot config .yaml file (not mandatory) to this script via the option
'--robot-config' so that it writes information about the robot (robot type, motor names) this dataset was
recorded with. For now, only Aloha/Koch type robots are supported with this option.


# 1. Single task dataset
If your dataset contains a single task, you can simply provide it directly via the CLI with the
'--single-task' option.

Examples:

```bash
python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
    --repo-id lerobot/aloha_sim_insertion_human_image \
    --single-task "Insert the peg into the socket." \
    --robot-config lerobot/configs/robot/aloha.yaml \
    --local-dir data
```

```bash
python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
    --repo-id aliberts/koch_tutorial \
    --single-task "Pick the Lego block and drop it in the box on the right." \
    --robot-config lerobot/configs/robot/koch.yaml \
    --local-dir data
```


# 2. Single task episodes
If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:

- If your dataset already contains a language instruction column in its parquet file, you can simply provide
  this column's name with the '--tasks-col' arg.

  Example:

  ```bash
  python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
      --repo-id lerobot/stanford_kuka_multimodal_dataset \
      --tasks-col "language_instruction" \
      --local-dir data
  ```

- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with
  the '--tasks-path' arg. This file should have the following structure, where keys correspond to each
  episode_index in the dataset and values are the language instruction for that episode.

  Example:

  ```json
  {
      "0": "Do something",
      "1": "Do something else",
      "2": "Do something",
      "3": "Go there",
      ...
  }
  ```

# 3. Multi task episodes
If you have multiple tasks per episode, your dataset should contain a language instruction column in its
parquet file, and you must provide this column's name with the '--tasks-col' arg.

Example:

```bash
python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
    --repo-id lerobot/stanford_kuka_multimodal_dataset \
    --tasks-col "language_instruction" \
    --local-dir data
```
"""

import argparse
import contextlib
import filecmp
import json
import logging
import math
import shutil
import subprocess
import tempfile
from pathlib import Path

import datasets
import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
from datasets import Dataset
from huggingface_hub import HfApi
from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
from safetensors.torch import load_file

from lerobot.datasets.utils import (
    DEFAULT_CHUNK_SIZE,
    DEFAULT_PARQUET_PATH,
    DEFAULT_VIDEO_PATH,
    EPISODES_PATH,
    INFO_PATH,
    STATS_PATH,
    TASKS_PATH,
    create_branch,
    create_lerobot_dataset_card,
    flatten_dict,
    get_safe_version,
    load_json,
    unflatten_dict,
    write_json,
    write_jsonlines,
)
from lerobot.datasets.video_utils import (
    VideoFrame,  # noqa: F401
    get_image_pixel_channels,
    get_video_info,
)
from lerobot.robots import RobotConfig

V16 = "v1.6"
V20 = "v2.0"

GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
V1_INFO_PATH = "meta_data/info.json"
V1_STATS_PATH = "meta_data/stats.safetensors"


def parse_robot_config(robot_cfg: RobotConfig) -> dict:
    if robot_cfg.type in ["aloha", "koch"]:
        state_names = [
            f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
            for arm in robot_cfg.follower_arms
            for motor in robot_cfg.follower_arms[arm].motors
        ]
        action_names = [
            # f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
            f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor
            for arm in robot_cfg.leader_arms
            for motor in robot_cfg.leader_arms[arm].motors
        ]
    # elif robot_cfg["robot_type"] == "stretch3": TODO
    else:
        raise NotImplementedError(
            "Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
        )

    return {
        "robot_type": robot_cfg.type,
        "names": {
            "observation.state": state_names,
            "observation.effort": state_names,
            "action": action_names,
        },
    }
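For reference, a sketch of the mapping `parse_robot_config()` returns for a single-arm robot; the motor names below are illustrative placeholders, not taken from a real config:

```python
# Hypothetical output of parse_robot_config() for a one-arm Koch config whose
# motors are named "shoulder_pan", "shoulder_lift", "elbow_flex" (made-up names).
{
    "robot_type": "koch",
    "names": {
        "observation.state": ["shoulder_pan", "shoulder_lift", "elbow_flex"],
        "observation.effort": ["shoulder_pan", "shoulder_lift", "elbow_flex"],
        "action": ["shoulder_pan", "shoulder_lift", "elbow_flex"],
    },
}
```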
def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
    safetensor_path = v1_dir / V1_STATS_PATH
    stats = load_file(safetensor_path)
    serialized_stats = {key: value.tolist() for key, value in stats.items()}
    serialized_stats = unflatten_dict(serialized_stats)

    json_path = v2_dir / STATS_PATH
    json_path.parent.mkdir(exist_ok=True, parents=True)
    with open(json_path, "w") as f:
        json.dump(serialized_stats, f, indent=4)

    # Sanity check
    with open(json_path) as f:
        stats_json = json.load(f)

    stats_json = flatten_dict(stats_json)
    stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
    for key in stats:
        torch.testing.assert_close(stats_json[key], stats[key])


def get_features_from_hf_dataset(
    dataset: Dataset, robot_config: RobotConfig | None = None
) -> dict[str, list]:
    # robot_config may be None when no robot config was provided
    if robot_config is not None:
        robot_config = parse_robot_config(robot_config)
    features = {}
    for key, ft in dataset.features.items():
        if isinstance(ft, datasets.Value):
            dtype = ft.dtype
            shape = (1,)
            names = None
        if isinstance(ft, datasets.Sequence):
            assert isinstance(ft.feature, datasets.Value)
            dtype = ft.feature.dtype
            shape = (ft.length,)
            motor_names = (
                robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
            )
            assert len(motor_names) == shape[0]
            names = {"motors": motor_names}
        elif isinstance(ft, datasets.Image):
            dtype = "image"
            image = dataset[0][key]  # Assuming first row
            channels = get_image_pixel_channels(image)
            shape = (image.height, image.width, channels)
            names = ["height", "width", "channels"]
        elif ft._type == "VideoFrame":
            dtype = "video"
            shape = None  # Add shape later
            names = ["height", "width", "channels"]

        features[key] = {
            "dtype": dtype,
            "shape": shape,
            "names": names,
        }

    return features


def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
    df = dataset.to_pandas()
    tasks = list(set(tasks_by_episodes.values()))
    tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
    episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
    df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)

    features = dataset.features
    features["task_index"] = datasets.Value(dtype="int64")
    dataset = Dataset.from_pandas(df, features=features, split="train")
    return dataset, tasks


def add_task_index_from_tasks_col(
    dataset: Dataset, tasks_col: str
) -> tuple[Dataset, dict[str, list[str]], list[str]]:
    df = dataset.to_pandas()

    # HACK: This is to clean some of the instructions in our version of Open X datasets
    prefix_to_clean = "tf.Tensor(b'"
    suffix_to_clean = "', shape=(), dtype=string)"
    df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)

    # Create task_index col
    tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
    tasks = df[tasks_col].unique().tolist()
    tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
    df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)

    # Build the dataset back from df
    features = dataset.features
    features["task_index"] = datasets.Value(dtype="int64")
    dataset = Dataset.from_pandas(df, features=features, split="train")
    dataset = dataset.remove_columns(tasks_col)

    return dataset, tasks, tasks_by_episode
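As a toy illustration (made-up episodes and tasks, not a real dataset) of the index bookkeeping performed by `add_task_index_by_episodes()`:

```python
# Toy data mirroring the mapping logic above.
tasks_by_episodes = {0: "Pick the cube", 1: "Open the drawer", 2: "Pick the cube"}
tasks = list(set(tasks_by_episodes.values()))
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
episodes_to_task_index = {ep: tasks_to_task_index[t] for ep, t in tasks_by_episodes.items()}
# If set() happens to yield ["Pick the cube", "Open the drawer"], then
# episodes_to_task_index == {0: 0, 1: 1, 2: 0}; the ordering may vary.
```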
def split_parquet_by_episodes(
    dataset: Dataset,
    total_episodes: int,
    total_chunks: int,
    output_dir: Path,
) -> list:
    table = dataset.data.table
    episode_lengths = []
    for ep_chunk in range(total_chunks):
        ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
        ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
        chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
        (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
        for ep_idx in range(ep_chunk_start, ep_chunk_end):
            ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
            episode_lengths.insert(ep_idx, len(ep_table))
            output_file = output_dir / DEFAULT_PARQUET_PATH.format(
                episode_chunk=ep_chunk, episode_index=ep_idx
            )
            pq.write_table(ep_table, output_file)

    return episode_lengths
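The chunk layout above follows from integer division of the episode index; a minimal standalone sketch, with a stand-in value for `DEFAULT_CHUNK_SIZE` (the real constant comes from `lerobot.datasets.utils` and may differ):

```python
# Illustrative episode -> chunk mapping used by the loop above.
CHUNK_SIZE = 1000  # stand-in for DEFAULT_CHUNK_SIZE

def chunk_of(ep_idx: int) -> int:
    return ep_idx // CHUNK_SIZE

assert chunk_of(0) == 0 and chunk_of(999) == 0 and chunk_of(1000) == 1
```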
def move_videos(
    repo_id: str,
    video_keys: list[str],
    total_episodes: int,
    total_chunks: int,
    work_dir: Path,
    clean_gitattributes: Path,
    branch: str = "main",
) -> None:
    """
    HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git
    commands to fetch git lfs video file references and move them into subdirectories, without having to
    actually download them.
    """
    _lfs_clone(repo_id, work_dir, branch)

    videos_moved = False
    video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
    if len(video_files) == 0:
        video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
        videos_moved = True  # Videos have already been moved

    assert len(video_files) == total_episodes * len(video_keys)

    lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)

    current_gitattributes = work_dir / ".gitattributes"
    if not filecmp.cmp(current_gitattributes, clean_gitattributes, shallow=False):
        fix_gitattributes(work_dir, current_gitattributes, clean_gitattributes)

    if lfs_untracked_videos:
        fix_lfs_video_files_tracking(work_dir, lfs_untracked_videos)

    if videos_moved:
        return

    video_dirs = sorted(work_dir.glob("videos*/"))
    for ep_chunk in range(total_chunks):
        ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
        ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
        for vid_key in video_keys:
            chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
                episode_chunk=ep_chunk, video_key=vid_key
            )
            (work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)

            for ep_idx in range(ep_chunk_start, ep_chunk_end):
                target_path = DEFAULT_VIDEO_PATH.format(
                    episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
                )
                video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
                if len(video_dirs) == 1:
                    video_path = video_dirs[0] / video_file
                else:
                    for dir in video_dirs:
                        if (dir / video_file).is_file():
                            video_path = dir / video_file
                            break

                video_path.rename(work_dir / target_path)

    commit_message = "Move video files into chunk subdirectories"
    subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
    subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
    subprocess.run(["git", "push"], cwd=work_dir, check=True)


def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
    """
    HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
    there's no other option than to download the actual files and reupload them with lfs tracking.
    """
    for i in range(0, len(lfs_untracked_videos), 100):
        files = lfs_untracked_videos[i : i + 100]
        try:
            subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
        except subprocess.CalledProcessError as e:
            print("git rm --cached ERROR:")
            print(e.stderr)
        subprocess.run(["git", "add", *files], cwd=work_dir, check=True)

    commit_message = "Track video files with git lfs"
    subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
    subprocess.run(["git", "push"], cwd=work_dir, check=True)


def fix_gitattributes(work_dir: Path, current_gitattributes: Path, clean_gitattributes: Path) -> None:
    shutil.copyfile(clean_gitattributes, current_gitattributes)
    subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
    subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
    subprocess.run(["git", "push"], cwd=work_dir, check=True)


def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
    subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
    repo_url = f"https://huggingface.co/datasets/{repo_id}"
    env = {"GIT_LFS_SKIP_SMUDGE": "1"}  # Prevent downloading LFS files
    subprocess.run(
        ["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
        check=True,
        env=env,
    )


def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
    lfs_tracked_files = subprocess.run(
        ["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
    )
    lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
    return [f for f in video_files if f not in lfs_tracked_files]


def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
    # Assumes first episode
    video_files = [
        DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
        for vid_key in video_keys
    ]
    hub_api = HfApi()
    hub_api.snapshot_download(
        repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
    )
    videos_info_dict = {}
    for vid_key, vid_path in zip(video_keys, video_files, strict=True):
        videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)

    return videos_info_dict


def convert_dataset(
    repo_id: str,
    local_dir: Path,
    single_task: str | None = None,
    tasks_path: Path | None = None,
    tasks_col: Path | None = None,
    robot_config: RobotConfig | None = None,
    test_branch: str | None = None,
    **card_kwargs,
):
    v1 = get_safe_version(repo_id, V16)
    v1x_dir = local_dir / V16 / repo_id
    v20_dir = local_dir / V20 / repo_id
    v1x_dir.mkdir(parents=True, exist_ok=True)
    v20_dir.mkdir(parents=True, exist_ok=True)

    hub_api = HfApi()
    hub_api.snapshot_download(
        repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
    )
    branch = "main"
    if test_branch:
        branch = test_branch
        create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")

    metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
    dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
    features = get_features_from_hf_dataset(dataset, robot_config)
    video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]

    if single_task and "language_instruction" in dataset.column_names:
        logging.warning(
            "'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
        )
        single_task = None
        tasks_col = "language_instruction"

    # Episodes & chunks
    episode_indices = sorted(dataset.unique("episode_index"))
    total_episodes = len(episode_indices)
    assert episode_indices == list(range(total_episodes))
    total_videos = total_episodes * len(video_keys)
    total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
    if total_episodes % DEFAULT_CHUNK_SIZE != 0:
        total_chunks += 1

    # Tasks
    if single_task:
        tasks_by_episodes = dict.fromkeys(episode_indices, single_task)
        dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
        tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
    elif tasks_path:
        tasks_by_episodes = load_json(tasks_path)
        tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
        dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
        tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
    elif tasks_col:
        dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
    else:
        raise ValueError

    assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
    tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
    write_jsonlines(tasks, v20_dir / TASKS_PATH)
    features["task_index"] = {
        "dtype": "int64",
        "shape": (1,),
        "names": None,
    }

    # Videos
    if video_keys:
        assert metadata_v1.get("video", False)
        dataset = dataset.remove_columns(video_keys)
        clean_gitattr = Path(
            hub_api.hf_hub_download(
                repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
            )
        ).absolute()
        with tempfile.TemporaryDirectory() as tmp_video_dir:
            move_videos(
                repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
            )
        videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
        for key in video_keys:
            features[key]["shape"] = (
                videos_info[key].pop("video.height"),
                videos_info[key].pop("video.width"),
                videos_info[key].pop("video.channels"),
            )
            features[key]["video_info"] = videos_info[key]
            assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
            if "encoding" in metadata_v1:
                assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
    else:
        assert metadata_v1.get("video", 0) == 0
        videos_info = None

    # Split data into 1 parquet file by episode
    episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)

    if robot_config is not None:
        robot_type = robot_config.type
        repo_tags = [robot_type]
    else:
        robot_type = "unknown"
        repo_tags = None

    # Episodes
    episodes = [
        {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
        for ep_idx in episode_indices
    ]
    write_jsonlines(episodes, v20_dir / EPISODES_PATH)

    # Assemble metadata v2.0
    metadata_v2_0 = {
        "codebase_version": V20,
        "robot_type": robot_type,
        "total_episodes": total_episodes,
        "total_frames": len(dataset),
        "total_tasks": len(tasks),
        "total_videos": total_videos,
        "total_chunks": total_chunks,
        "chunks_size": DEFAULT_CHUNK_SIZE,
        "fps": metadata_v1["fps"],
        "splits": {"train": f"0:{total_episodes}"},
        "data_path": DEFAULT_PARQUET_PATH,
        "video_path": DEFAULT_VIDEO_PATH if video_keys else None,
        "features": features,
    }
    write_json(metadata_v2_0, v20_dir / INFO_PATH)
    convert_stats_to_json(v1x_dir, v20_dir)
    card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)

    with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
        hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)

    with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
        hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)

    with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
        hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)

    hub_api.upload_folder(
        repo_id=repo_id,
        path_in_repo="data",
        folder_path=v20_dir / "data",
        repo_type="dataset",
        revision=branch,
    )
    hub_api.upload_folder(
        repo_id=repo_id,
        path_in_repo="meta",
        folder_path=v20_dir / "meta",
        repo_type="dataset",
        revision=branch,
    )

    card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)

    if not test_branch:
        create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")


def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
    if robot_type == "aloha":
        raise NotImplementedError  # TODO
    elif robot_type == "koch_follower":
        from lerobot.robots.koch_follower import KochFollowerConfig

        return KochFollowerConfig(**kwargs)
    elif robot_type == "so100_follower":
        from lerobot.robots.so100_follower import SO100FollowerConfig

        return SO100FollowerConfig(**kwargs)
    elif robot_type == "stretch":
        from lerobot.robots.stretch3 import Stretch3RobotConfig

        return Stretch3RobotConfig(**kwargs)
    elif robot_type == "lekiwi":
        from lerobot.robots.lekiwi import LeKiwiConfig

        return LeKiwiConfig(**kwargs)
    else:
        raise ValueError(f"Robot type '{robot_type}' is not available.")


def main():
    parser = argparse.ArgumentParser()
    task_args = parser.add_mutually_exclusive_group(required=True)

    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
    )
    task_args.add_argument(
        "--single-task",
        type=str,
        help="A short but accurate description of the single task performed in the dataset.",
    )
    task_args.add_argument(
        "--tasks-col",
        type=str,
        help="The name of the column containing language instructions",
    )
    task_args.add_argument(
        "--tasks-path",
        type=Path,
        help="The path to a .json file containing one language instruction for each episode_index",
    )
    parser.add_argument(
        "--robot",
        type=str,
        default=None,
        help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)",
    )
    parser.add_argument(
        "--local-dir",
        type=Path,
        default=None,
        help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
    )
    parser.add_argument(
        "--license",
        type=str,
        default="apache-2.0",
        help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to apache-2.0.",
    )
    parser.add_argument(
        "--test-branch",
        type=str,
        default=None,
        help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
    )

    args = parser.parse_args()
    if not args.local_dir:
        args.local_dir = Path("/tmp/lerobot_dataset_v2")

    robot_config = None
    if args.robot is not None:
        robot_config = make_robot_config(args.robot)

    del args.robot

    convert_dataset(**vars(args), robot_config=robot_config)


if __name__ == "__main__":
    main()
@@ -1,87 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import traceback
from pathlib import Path

from datasets import get_dataset_config_info
from huggingface_hub import HfApi

from lerobot import available_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import INFO_PATH, write_info
from lerobot.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings

LOCAL_DIR = Path("data/")

hub_api = HfApi()


def fix_dataset(repo_id: str) -> str:
    if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
        return f"{repo_id}: skipped (not in {V20})."

    dataset_info = get_dataset_config_info(repo_id, "default")
    with SuppressWarnings():
        lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)

    meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
    parquet_features = set(dataset_info.features)

    diff_parquet_meta = parquet_features - meta_features
    diff_meta_parquet = meta_features - parquet_features

    if diff_parquet_meta:
        raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")

    if not diff_meta_parquet:
        return f"{repo_id}: skipped (no diff)"

    if diff_meta_parquet:
        logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
        assert diff_meta_parquet == {"language_instruction"}
        lerobot_metadata.features.pop("language_instruction")
        write_info(lerobot_metadata.info, lerobot_metadata.root)
        commit_info = hub_api.upload_file(
            path_or_fileobj=lerobot_metadata.root / INFO_PATH,
            path_in_repo=INFO_PATH,
            repo_id=repo_id,
            repo_type="dataset",
            revision=V20,
            commit_message="Remove 'language_instruction'",
            create_pr=True,
        )
        return f"{repo_id}: success - PR: {commit_info.pr_url}"


def batch_fix():
    status = {}
    LOCAL_DIR.mkdir(parents=True, exist_ok=True)
    logfile = LOCAL_DIR / "fix_features_v20.txt"
    for num, repo_id in enumerate(available_datasets):
        print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
        print("---------------------------------------------------------")
        try:
            status = fix_dataset(repo_id)
        except Exception:
            status = f"{repo_id}: failed\n    {traceback.format_exc()}"

        logging.info(status)
        with open(logfile, "a") as file:
            file.write(status + "\n")


if __name__ == "__main__":
    batch_fix()
@@ -1,54 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
"""

import traceback
from pathlib import Path

from huggingface_hub import HfApi

from lerobot import available_datasets
from lerobot.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset

LOCAL_DIR = Path("data/")


def batch_convert():
    status = {}
    LOCAL_DIR.mkdir(parents=True, exist_ok=True)
    logfile = LOCAL_DIR / "conversion_log_v21.txt"
    hub_api = HfApi()
    for num, repo_id in enumerate(available_datasets):
        print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
        print("---------------------------------------------------------")
        try:
            if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
                status = f"{repo_id}: success (already in {V21})."
            else:
                convert_dataset(repo_id)
                status = f"{repo_id}: success."
        except Exception:
            status = f"{repo_id}: failed\n    {traceback.format_exc()}"

        with open(logfile, "a") as file:
            file.write(status + "\n")


if __name__ == "__main__":
    batch_convert()
@@ -1,114 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
2.1. It will:

- Generate per-episode stats and write them in `episodes_stats.jsonl`.
- Check consistency between these new stats and the old ones.
- Remove the deprecated `stats.json`.
- Update codebase_version in `info.json`.
- Push this new version to the hub on the 'main' branch and tag it with "v2.1".

Usage:

```bash
python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \
    --repo-id=aliberts/koch_tutorial
```

"""

import argparse
import logging

from huggingface_hub import HfApi

from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.datasets.v21.convert_stats import check_aggregate_stats, convert_stats

V20 = "v2.0"
V21 = "v2.1"


class SuppressWarnings:
    def __enter__(self):
        self.previous_level = logging.getLogger().getEffectiveLevel()
        logging.getLogger().setLevel(logging.ERROR)

    def __exit__(self, exc_type, exc_val, exc_tb):
        logging.getLogger().setLevel(self.previous_level)


def convert_dataset(
    repo_id: str,
    branch: str | None = None,
    num_workers: int = 4,
):
    with SuppressWarnings():
        dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)

    if (dataset.root / EPISODES_STATS_PATH).is_file():
        (dataset.root / EPISODES_STATS_PATH).unlink()

    convert_stats(dataset, num_workers=num_workers)
    ref_stats = load_stats(dataset.root)
    check_aggregate_stats(dataset, ref_stats)

    dataset.meta.info["codebase_version"] = CODEBASE_VERSION
    write_info(dataset.meta.info, dataset.root)

    dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")

    # delete old stats.json file (note: .is_file() must be called, not referenced)
    if (dataset.root / STATS_PATH).is_file():
        (dataset.root / STATS_PATH).unlink()

    hub_api = HfApi()
    if hub_api.file_exists(
        repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
    ):
        hub_api.delete_file(
            path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
        )

    hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
        "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
    )
    parser.add_argument(
        "--branch",
        type=str,
        default=None,
        help="Repo branch to push your dataset. Defaults to the main branch.",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of workers for parallelizing stats compute. Defaults to 4.",
    )

    args = parser.parse_args()
    convert_dataset(**vars(args))
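The converter above can also be driven programmatically; a minimal sketch using the import path from this module's own usage docstring (the repo id is a placeholder):

```python
from lerobot.datasets.v21.convert_dataset_v20_to_v21 import convert_dataset

# Hypothetical programmatic invocation; "user/my_dataset" is a placeholder repo id.
convert_dataset(repo_id="user/my_dataset", branch=None, num_workers=8)
```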
@@ -1,99 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
from tqdm import tqdm

from lerobot.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import write_episode_stats


def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
    ep_len = dataset.meta.episodes[episode_index]["length"]
    sampled_indices = sample_indices(ep_len)
    query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
    video_frames = dataset._query_videos(query_timestamps, episode_index)
    return video_frames[ft_key].numpy()


def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
    ep_start_idx = dataset.episode_data_index["from"][ep_idx]
    ep_end_idx = dataset.episode_data_index["to"][ep_idx]
    ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))

    ep_stats = {}
    for key, ft in dataset.features.items():
        if ft["dtype"] == "video":
            # We sample only for videos
            ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
        else:
            ep_ft_data = np.array(ep_data[key])

        axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
        keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
        ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)

        if ft["dtype"] in ["image", "video"]:  # remove batch dim
            ep_stats[key] = {
                k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
            }

    dataset.meta.episodes_stats[ep_idx] = ep_stats


def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
    assert dataset.episodes is None
    print("Computing episodes stats")
    total_episodes = dataset.meta.total_episodes
    if num_workers > 0:
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = {
                executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
                for ep_idx in range(total_episodes)
            }
            for future in tqdm(as_completed(futures), total=total_episodes):
                future.result()
    else:
        for ep_idx in tqdm(range(total_episodes)):
            convert_episode_stats(dataset, ep_idx)

    for ep_idx in tqdm(range(total_episodes)):
        write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)


def check_aggregate_stats(
    dataset: LeRobotDataset,
    reference_stats: dict[str, dict[str, np.ndarray]],
    video_rtol_atol: tuple[float, float] = (1e-2, 1e-2),
    default_rtol_atol: tuple[float, float] = (5e-6, 6e-5),
):
    """Verifies that the aggregated stats from episodes_stats are close to reference stats."""
    agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
    for key, ft in dataset.features.items():
        # These values might need some fine-tuning
        if ft["dtype"] == "video":
            # to account for image sub-sampling
            rtol, atol = video_rtol_atol
        else:
            rtol, atol = default_rtol_atol

        for stat, val in agg_stats[key].items():
            if key in reference_stats and stat in reference_stats[key]:
                err_msg = f"feature='{key}' stats='{stat}'"
                np.testing.assert_allclose(
                    val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
                )
@@ -0,0 +1,500 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.1 to
3.0. It will:

- Generate per-episode stats and write them in `episodes_stats.jsonl`.
- Check consistency between these new stats and the old ones.
- Remove the deprecated `stats.json`.
- Update codebase_version in `info.json`.
- Push this new version to the hub on the 'main' branch and tag it with "v3.0".

Usage:

```bash
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
    --repo-id=lerobot/pusht
```

"""

import argparse
import shutil
from pathlib import Path
from typing import Any

import jsonlines
import pandas as pd
import pyarrow as pa
import tqdm
from datasets import Dataset, Features, Image
from huggingface_hub import HfApi, snapshot_download
from requests import HTTPError

from lerobot.constants import HF_LEROBOT_HOME
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.datasets.utils import (
    DEFAULT_CHUNK_SIZE,
    DEFAULT_DATA_FILE_SIZE_IN_MB,
    DEFAULT_DATA_PATH,
    DEFAULT_VIDEO_FILE_SIZE_IN_MB,
    DEFAULT_VIDEO_PATH,
    LEGACY_EPISODES_PATH,
    LEGACY_EPISODES_STATS_PATH,
    LEGACY_TASKS_PATH,
    cast_stats_to_numpy,
    flatten_dict,
    get_parquet_file_size_in_mb,
    get_parquet_num_frames,
    get_video_size_in_mb,
    load_info,
    update_chunk_file_indices,
    write_episodes,
    write_info,
    write_stats,
    write_tasks,
)
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s

V21 = "v2.1"


"""
-------------------------
OLD
data/chunk-000/episode_000000.parquet

NEW
data/chunk-000/file_000.parquet
-------------------------
OLD
videos/chunk-000/CAMERA/episode_000000.mp4

NEW
videos/chunk-000/file_000.mp4
-------------------------
OLD
episodes.jsonl
{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}

NEW
meta/episodes/chunk-000/episodes_000.parquet
episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
-------------------------
OLD
tasks.jsonl
{"task_index": 1, "task": "Put the blue block in the green bowl"}

NEW
meta/tasks/chunk-000/file_000.parquet
task_index | task
-------------------------
OLD
episodes_stats.jsonl

NEW
meta/episodes_stats/chunk-000/file_000.parquet
episode_index | mean | std | min | max
-------------------------
UPDATE
meta/info.json
-------------------------
"""
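To make the layout comment above concrete, a toy sketch of the data-path change (paths and indices are illustrative; in v3.0 several episodes are concatenated into one size-bounded file):

```python
# Toy illustration of the v2.1 -> v3.0 data layout migration described above.
old_paths = [
    "data/chunk-000/episode_000000.parquet",
    "data/chunk-000/episode_000001.parquet",
]
# Both episodes could land in a single v3.0 file named by chunk and file indices:
new_path = "data/chunk-000/file_000.parquet"
```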
def load_jsonlines(fpath: Path) -> list[Any]:
|
||||
with jsonlines.open(fpath, "r") as reader:
|
||||
return list(reader)
|
||||
|
||||
|
||||
def legacy_load_episodes(local_dir: Path) -> dict:
|
||||
episodes = load_jsonlines(local_dir / LEGACY_EPISODES_PATH)
|
||||
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
|
||||
|
||||
|
||||
def legacy_load_episodes_stats(local_dir: Path) -> dict:
|
||||
episodes_stats = load_jsonlines(local_dir / LEGACY_EPISODES_STATS_PATH)
|
||||
return {
|
||||
item["episode_index"]: cast_stats_to_numpy(item["stats"])
|
||||
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
|
||||
}
|
||||
|
||||
|
||||
def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
tasks = load_jsonlines(local_dir / LEGACY_TASKS_PATH)
|
||||
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
|
||||
return tasks, task_to_task_index
|
||||
|
||||
|
||||
def convert_tasks(root, new_root):
|
||||
tasks, _ = legacy_load_tasks(root)
|
||||
task_indices = tasks.keys()
|
||||
task_strings = tasks.values()
|
||||
df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings)
|
||||
write_tasks(df_tasks, new_root)
|
||||
|
||||
|
||||
def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys):
|
||||
# TODO(rcadene): to save RAM use Dataset.from_parquet(file) and concatenate_datasets
|
||||
dataframes = [pd.read_parquet(file) for file in paths_to_cat]
|
||||
# Concatenate all DataFrames along rows
|
||||
concatenated_df = pd.concat(dataframes, ignore_index=True)
|
||||
|
||||
path = new_root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if len(image_keys) > 0:
|
||||
schema = pa.Schema.from_pandas(concatenated_df)
|
||||
features = Features.from_arrow_schema(schema)
|
||||
for key in image_keys:
|
||||
features[key] = Image()
|
||||
schema = features.arrow_schema
|
||||
else:
|
||||
schema = None
|
||||
|
||||
concatenated_df.to_parquet(path, index=False, schema=schema)
|
||||
|
||||
|
||||
def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
data_dir = root / "data"
|
||||
ep_paths = sorted(data_dir.glob("*/*.parquet"))
|
||||
|
||||
image_keys = get_image_keys(root)
|
||||
|
||||
ep_idx = 0
|
||||
chunk_idx = 0
|
||||
file_idx = 0
|
||||
size_in_mb = 0
|
||||
num_frames = 0
|
||||
paths_to_cat = []
|
||||
episodes_metadata = []
|
||||
for ep_path in ep_paths:
|
||||
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
|
||||
ep_num_frames = get_parquet_num_frames(ep_path)
|
||||
ep_metadata = {
|
||||
"episode_index": ep_idx,
|
||||
"data/chunk_index": chunk_idx,
|
||||
"data/file_index": file_idx,
|
||||
"dataset_from_index": num_frames,
|
||||
"dataset_to_index": num_frames + ep_num_frames,
|
||||
}
|
||||
size_in_mb += ep_size_in_mb
|
||||
num_frames += ep_num_frames
|
||||
episodes_metadata.append(ep_metadata)
|
||||
ep_idx += 1
|
||||
|
||||
if size_in_mb < data_file_size_in_mb:
|
||||
paths_to_cat.append(ep_path)
|
||||
continue
|
||||
|
||||
if paths_to_cat:
|
||||
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
|
||||
|
||||
# Reset for the next file
|
||||
size_in_mb = ep_size_in_mb
|
||||
num_frames = ep_num_frames
|
||||
paths_to_cat = [ep_path]
|
||||
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
|
||||
# Write remaining data if any
|
||||
if paths_to_cat:
|
||||
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
|
||||
|
||||
return episodes_metadata
|
||||
|
||||
|
||||
def get_video_keys(root):
|
||||
info = load_info(root)
|
||||
features = info["features"]
|
||||
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
|
||||
return video_keys
|
||||
|
||||
|
||||
def get_image_keys(root):
|
||||
info = load_info(root)
|
||||
features = info["features"]
|
||||
image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"]
|
||||
return image_keys
|
||||
|
||||
|
||||
def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int):
|
||||
video_keys = get_video_keys(root)
|
||||
if len(video_keys) == 0:
|
||||
return None
|
||||
|
||||
video_keys = sorted(video_keys)
|
||||
|
||||
    eps_metadata_per_cam = []
    for camera in video_keys:
        eps_metadata = convert_videos_of_camera(root, new_root, camera, video_file_size_in_mb)
        eps_metadata_per_cam.append(eps_metadata)

    num_eps_per_cam = [len(eps_cam_map) for eps_cam_map in eps_metadata_per_cam]
    if len(set(num_eps_per_cam)) != 1:
        raise ValueError(f"All cameras don't have the same number of episodes ({num_eps_per_cam}).")

    episodes_metadata = []
    num_cameras = len(video_keys)
    num_episodes = num_eps_per_cam[0]
    for ep_idx in range(num_episodes):
        # Sanity check
        ep_ids = [eps_metadata_per_cam[cam_idx][ep_idx]["episode_index"] for cam_idx in range(num_cameras)]
        ep_ids += [ep_idx]
        if len(set(ep_ids)) != 1:
            raise ValueError(f"All episode indices need to match ({ep_ids}).")

        ep_dict = {}
        for cam_idx in range(num_cameras):
            ep_dict.update(eps_metadata_per_cam[cam_idx][ep_idx])
        episodes_metadata.append(ep_dict)

    return episodes_metadata


def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_file_size_in_mb: int):
    # Access old paths to mp4
    videos_dir = root / "videos"
    ep_paths = sorted(videos_dir.glob(f"*/{video_key}/*.mp4"))

    ep_idx = 0
    chunk_idx = 0
    file_idx = 0
    size_in_mb = 0
    duration_in_s = 0.0
    paths_to_cat = []
    episodes_metadata = []
    for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
        ep_size_in_mb = get_video_size_in_mb(ep_path)
        ep_duration_in_s = get_video_duration_in_s(ep_path)

        # Check if adding this episode would exceed the limit
        if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0:
            # Size limit would be exceeded, save current accumulation WITHOUT this episode
            concatenate_video_files(
                paths_to_cat,
                new_root
                / DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
            )

            # Update episodes metadata for the file we just saved
            for i, _ in enumerate(paths_to_cat):
                past_ep_idx = ep_idx - len(paths_to_cat) + i
                episodes_metadata[past_ep_idx][f"videos/{video_key}/chunk_index"] = chunk_idx
                episodes_metadata[past_ep_idx][f"videos/{video_key}/file_index"] = file_idx

            # Move to next file and start fresh with current episode
            chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
            size_in_mb = 0
            duration_in_s = 0.0
            paths_to_cat = []

        # Add current episode metadata
        ep_metadata = {
            "episode_index": ep_idx,
            f"videos/{video_key}/chunk_index": chunk_idx,  # Will be updated when file is saved
            f"videos/{video_key}/file_index": file_idx,  # Will be updated when file is saved
            f"videos/{video_key}/from_timestamp": duration_in_s,
            f"videos/{video_key}/to_timestamp": duration_in_s + ep_duration_in_s,
        }
        episodes_metadata.append(ep_metadata)

        # Add current episode to accumulation
        paths_to_cat.append(ep_path)
        size_in_mb += ep_size_in_mb
        duration_in_s += ep_duration_in_s
        ep_idx += 1

    # Write remaining videos if any
    if paths_to_cat:
        concatenate_video_files(
            paths_to_cat,
            new_root
            / DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
        )

        # Update episodes metadata for the final file
        for i, _ in enumerate(paths_to_cat):
            past_ep_idx = ep_idx - len(paths_to_cat) + i
            episodes_metadata[past_ep_idx][f"videos/{video_key}/chunk_index"] = chunk_idx
            episodes_metadata[past_ep_idx][f"videos/{video_key}/file_index"] = file_idx

    return episodes_metadata
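Each episode's frames end up inside one concatenated mp4, addressed by a (chunk_index, file_index) pair plus a [from_timestamp, to_timestamp) window. A minimal sketch of how a reader could resolve that addressing; the helper name `episode_video_slice` is hypothetical, while `DEFAULT_VIDEO_PATH` is the same formattable string used by the converter above:

```python
# Hypothetical helper (not part of this PR): resolve where one episode lives
# inside the concatenated v3.0 video files, using the metadata produced above.
from pathlib import Path


def episode_video_slice(root: Path, video_key: str, ep_metadata: dict) -> tuple[Path, float, float]:
    """Return the concatenated file and the [from, to) timestamp window of one episode."""
    chunk_idx = ep_metadata[f"videos/{video_key}/chunk_index"]
    file_idx = ep_metadata[f"videos/{video_key}/file_index"]
    video_path = root / DEFAULT_VIDEO_PATH.format(
        video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
    )
    return (
        video_path,
        ep_metadata[f"videos/{video_key}/from_timestamp"],
        ep_metadata[f"videos/{video_key}/to_timestamp"],
    )
```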
def generate_episode_metadata_dict(
    episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None
):
    num_episodes = len(episodes_metadata)
    episodes_legacy_metadata_vals = list(episodes_legacy_metadata.values())
    episodes_stats_vals = list(episodes_stats.values())
    episodes_stats_keys = list(episodes_stats.keys())

    for i in range(num_episodes):
        ep_legacy_metadata = episodes_legacy_metadata_vals[i]
        ep_metadata = episodes_metadata[i]
        ep_stats = episodes_stats_vals[i]

        ep_ids_set = {
            ep_legacy_metadata["episode_index"],
            ep_metadata["episode_index"],
            episodes_stats_keys[i],
        }

        if episodes_videos is None:
            ep_video = {}
        else:
            ep_video = episodes_videos[i]
            ep_ids_set.add(ep_video["episode_index"])

        if len(ep_ids_set) != 1:
            raise ValueError(f"Episode indices do not match across sources ({ep_ids_set}).")

        ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
        ep_dict["meta/episodes/chunk_index"] = 0
        ep_dict["meta/episodes/file_index"] = 0
        yield ep_dict


def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
    episodes_legacy_metadata = legacy_load_episodes(root)
    episodes_stats = legacy_load_episodes_stats(root)

    num_eps_set = {len(episodes_legacy_metadata), len(episodes_metadata)}
    if episodes_video_metadata is not None:
        num_eps_set.add(len(episodes_video_metadata))

    if len(num_eps_set) != 1:
        raise ValueError(f"Number of episodes is not the same ({num_eps_set}).")

    ds_episodes = Dataset.from_generator(
        lambda: generate_episode_metadata_dict(
            episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
        )
    )
    write_episodes(ds_episodes, new_root)

    stats = aggregate_stats(list(episodes_stats.values()))
    write_stats(stats, new_root)


def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
    info = load_info(root)
    info["codebase_version"] = "v3.0"
    del info["total_chunks"]
    del info["total_videos"]
    info["data_files_size_in_mb"] = data_file_size_in_mb
    info["video_files_size_in_mb"] = video_file_size_in_mb
    info["data_path"] = DEFAULT_DATA_PATH
    info["video_path"] = DEFAULT_VIDEO_PATH
    info["fps"] = float(info["fps"])
    for key in info["features"]:
        if info["features"][key]["dtype"] == "video":
            # already has fps in video_info
            continue
        info["features"][key]["fps"] = info["fps"]
    write_info(info, new_root)


def convert_dataset(
    repo_id: str,
    branch: str | None = None,
    data_file_size_in_mb: int | None = None,
    video_file_size_in_mb: int | None = None,
):
    root = HF_LEROBOT_HOME / repo_id
    old_root = HF_LEROBOT_HOME / f"{repo_id}_old"
    new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"

    if data_file_size_in_mb is None:
        data_file_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
    if video_file_size_in_mb is None:
        video_file_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB

    if old_root.is_dir() and root.is_dir():
        shutil.rmtree(str(root))
        shutil.move(str(old_root), str(root))

    if new_root.is_dir():
        shutil.rmtree(new_root)

    snapshot_download(
        repo_id,
        repo_type="dataset",
        revision=V21,
        local_dir=root,
    )

    convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb)
    convert_tasks(root, new_root)
    episodes_metadata = convert_data(root, new_root, data_file_size_in_mb)
    episodes_videos_metadata = convert_videos(root, new_root, video_file_size_in_mb)
    convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata)

    shutil.move(str(root), str(old_root))
    shutil.move(str(new_root), str(root))

    hub_api = HfApi()
    try:
        hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
    except HTTPError as e:
        print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
    hub_api.delete_files(
        delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
        repo_id=repo_id,
        revision=branch,
        repo_type="dataset",
    )
    hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")

    LeRobotDataset(repo_id).push_to_hub()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
        "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
    )
    parser.add_argument(
        "--branch",
        type=str,
        default=None,
        help="Repo branch to push your dataset to. Defaults to the main branch.",
    )
    parser.add_argument(
        "--data-file-size-in-mb",
        type=int,
        default=None,
        help="Maximum data file size in MB. Defaults to 100.",
    )
    parser.add_argument(
        "--video-file-size-in-mb",
        type=int,
        default=None,
        help="Maximum video file size in MB. Defaults to 500.",
    )

    args = parser.parse_args()
    convert_dataset(**vars(args))
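For reference, a hedged sketch of driving the converter programmatically, equivalent to the CLI entry point above; the repo id is illustrative:

```python
# Equivalent to: python convert_dataset_v21_to_v30.py --repo-id lerobot/pusht
# (script name assumed). Requires the v2.1 dataset on the Hub and Hub authentication.
convert_dataset(
    repo_id="lerobot/pusht",
    branch=None,                # push to the main branch
    data_file_size_in_mb=100,   # the defaults, shown explicitly
    video_file_size_in_mb=500,
)
```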
@@ -17,12 +17,15 @@ import glob
import importlib
import logging
import shutil
import tempfile
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from threading import Lock
from typing import Any, ClassVar

import av
import fsspec
import pyarrow as pa
import torch
import torchvision
@@ -168,15 +171,68 @@ def decode_video_frames_torchvision(
    return closest_frames


class VideoDecoderCache:
    """Thread-safe cache for video decoders to avoid expensive re-initialization."""

    def __init__(self):
        self._cache: dict[str, tuple[Any, Any]] = {}
        self._lock = Lock()

    def get_decoder(self, video_path: str):
        """Get a cached decoder or create a new one."""
        if importlib.util.find_spec("torchcodec"):
            from torchcodec.decoders import VideoDecoder
        else:
            raise ImportError("torchcodec is required but not available.")

        video_path = str(video_path)

        with self._lock:
            if video_path not in self._cache:
                file_handle = fsspec.open(video_path).__enter__()
                decoder = VideoDecoder(file_handle, seek_mode="approximate")
                self._cache[video_path] = (decoder, file_handle)

            return self._cache[video_path][0]

    def clear(self):
        """Clear the cache and close file handles."""
        with self._lock:
            for _, file_handle in self._cache.values():
                file_handle.close()
            self._cache.clear()

    def size(self) -> int:
        """Return the number of cached decoders."""
        with self._lock:
            return len(self._cache)


class FrameTimestampError(ValueError):
    """Helper error to indicate that the retrieved timestamps differ from the queried ones."""

    pass


_default_decoder_cache = VideoDecoderCache()

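The cache keys decoders by path and keeps the fsspec file handle alive alongside each decoder so repeated reads from the same file skip re-initialization. A short usage sketch, assuming torchcodec is installed (file name illustrative):

```python
cache = VideoDecoderCache()
decoder = cache.get_decoder("episode_000000.mp4")  # first call opens the file
decoder = cache.get_decoder("episode_000000.mp4")  # second call hits the cache
print(cache.size())  # -> 1
cache.clear()        # closes the underlying fsspec file handles
```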
def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    device: str = "cpu",
    log_loaded_timestamps: bool = False,
    decoder_cache: VideoDecoderCache | None = None,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec.

    Args:
        video_path: Path to the video file.
        timestamps: List of timestamps to extract frames.
        tolerance_s: Allowed deviation in seconds for frame retrieval.
        log_loaded_timestamps: Whether to log loaded timestamps.
        decoder_cache: Optional decoder cache instance. Uses the default if None.

    Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
@@ -185,27 +241,24 @@ def decode_video_frames_torchcodec(
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.
    """
    if decoder_cache is None:
        decoder_cache = _default_decoder_cache

    if importlib.util.find_spec("torchcodec"):
        from torchcodec.decoders import VideoDecoder
    else:
        raise ImportError("torchcodec is required but not available.")
    # Use cached decoder instead of creating a new one each time
    decoder = decoder_cache.get_decoder(str(video_path))

    # initialize video decoder
    decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
    loaded_frames = []
    loaded_ts = []
    loaded_frames = []

    # get metadata for frame information
    metadata = decoder.metadata
    average_fps = metadata.average_fps

    # convert timestamps to frame indices
    frame_indices = [round(ts * average_fps) for ts in timestamps]

    # retrieve frames based on indices
    frames_batch = decoder.get_frames_at(indices=frame_indices)

    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
        loaded_frames.append(frame)
        loaded_ts.append(pts.item())
        if log_loaded_timestamps:
@@ -236,10 +289,14 @@ def decode_video_frames_torchcodec(
        if log_loaded_timestamps:
            logging.info(f"{closest_ts=}")

    # convert to float32 in [0,1] range (channel first)
    closest_frames = closest_frames.type(torch.float32) / 255
    # convert to float32 in [0,1] range
    closest_frames = (closest_frames / 255.0).type(torch.float32)

    if not len(timestamps) == len(closest_frames):
        raise FrameTimestampError(
            f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}"
        )

    assert len(timestamps) == len(closest_frames)
    return closest_frames

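The decoder maps each queried timestamp to a frame index via the stream's average fps, i.e. `idx = round(ts * average_fps)`. A worked toy example of that mapping (values illustrative):

```python
# For a 30 fps stream, timestamps land on the nearest frame index:
average_fps = 30.0
timestamps = [0.0, 0.034, 1.0]
frame_indices = [round(ts * average_fps) for ts in timestamps]
assert frame_indices == [0, 1, 30]
```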
@@ -263,7 +320,11 @@ def encode_video_frames(
    video_path = Path(video_path)
    imgs_dir = Path(imgs_dir)

    video_path.parent.mkdir(parents=True, exist_ok=overwrite)
    if video_path.exists() and not overwrite:
        logging.warning(f"Video file already exists: {video_path}. Skipping encoding.")
        return

    video_path.parent.mkdir(parents=True, exist_ok=True)

    # Encoders/pixel formats incompatibility check
    if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
@@ -273,9 +334,9 @@ def encode_video_frames(
        pix_fmt = "yuv420p"

    # Get input frames
    template = "frame_" + ("[0-9]" * 6) + ".png"
    template = "frame-" + ("[0-9]" * 6) + ".png"
    input_list = sorted(
        glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("_")[-1].split(".")[0])
        glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
    )

    # Define video output frame size (assuming all input frames are the same size)
@@ -300,7 +361,7 @@ def encode_video_frames(

    # Set logging level
    if log_level is not None:
        # "While less efficient, it is generally preferable to modify logging with Python’s logging"
        # "While less efficient, it is generally preferable to modify logging with Python's logging"
        logging.getLogger("libav").setLevel(log_level)

    # Create and open output file (overwrite by default)
@@ -331,6 +392,89 @@ def encode_video_frames(
        raise OSError(f"Video encoding did not work. File not found: {video_path}.")


def concatenate_video_files(
    input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
):
    """
    Concatenate multiple video files into a single video file using pyav.

    This function takes a list of input video file paths and concatenates them into a single
    output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast
    concatenation without re-encoding.

    Args:
        input_video_paths: Ordered list of input video file paths to concatenate.
        output_video_path: Path to the output video file.
        overwrite: Whether to overwrite the output video file if it already exists. Default is True.

    Note:
        - Creates temporary files (the concat list and an intermediate output) that are cleaned up after use.
        - Uses ffmpeg's concat demuxer, which requires all input videos to have the same
          codec, resolution, and frame rate for proper concatenation.
    """

    output_video_path = Path(output_video_path)

    if output_video_path.exists() and not overwrite:
        logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
        return

    output_video_path.parent.mkdir(parents=True, exist_ok=True)

    if len(input_video_paths) == 0:
        raise FileNotFoundError("No input video paths provided.")

    # Create a temporary .ffconcat file to list the input video paths
    with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
        tmp_concatenate_file.write("ffconcat version 1.0\n")
        for input_path in input_video_paths:
            tmp_concatenate_file.write(f"file '{str(input_path)}'\n")
        tmp_concatenate_file.flush()
        tmp_concatenate_path = tmp_concatenate_file.name

    # Create input and output containers
    input_container = av.open(
        tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
    )  # safe = 0 allows absolute paths as well as relative paths

    tmp_output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    output_container = av.open(
        tmp_output_video_path, mode="w", options={"movflags": "faststart"}
    )  # faststart moves the metadata to the beginning of the file to speed up loading

    # Replicate input streams in the output container
    stream_map = {}
    for input_stream in input_container.streams:
        if input_stream.type in ("video", "audio", "subtitle"):  # only copy compatible streams
            stream_map[input_stream.index] = output_container.add_stream_from_template(
                template=input_stream, opaque=True
            )
            # set the time base to the input stream time base (missing in the codec context)
            stream_map[input_stream.index].time_base = input_stream.time_base

    # Demux + remux packets (no re-encode)
    for packet in input_container.demux():
        # Skip packets from un-mapped streams
        if packet.stream.index not in stream_map:
            continue

        # Skip demux flushing packets
        if packet.dts is None:
            continue

        output_stream = stream_map[packet.stream.index]
        packet.stream = output_stream
        output_container.mux(packet)

    input_container.close()
    output_container.close()
    shutil.move(tmp_output_video_path, output_video_path)
    Path(tmp_concatenate_path).unlink()

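Because packets are remuxed rather than re-encoded, concatenation is I/O-bound. A hedged usage sketch (paths illustrative):

```python
# Pack three per-episode mp4s (same codec/resolution/fps) into one file.
concatenate_video_files(
    input_video_paths=[
        "videos/episode_000000.mp4",
        "videos/episode_000001.mp4",
        "videos/episode_000002.mp4",
    ],
    output_video_path=Path("videos/observation.images.cam_high/chunk-000/file-000.mp4"),
)
```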
@dataclass
class VideoFrame:
    # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
@@ -454,6 +598,28 @@ def get_image_pixel_channels(image: Image):
    raise ValueError("Unknown format")


def get_video_duration_in_s(video_path: Path | str) -> float:
    """
    Get the duration of a video file in seconds using PyAV.

    Args:
        video_path: Path to the video file.

    Returns:
        Duration of the video in seconds.
    """
    with av.open(str(video_path)) as container:
        # Get the first video stream
        video_stream = container.streams.video[0]
        # Calculate duration: stream.duration * stream.time_base gives duration in seconds
        if video_stream.duration is not None:
            duration = float(video_stream.duration * video_stream.time_base)
        else:
            # Fall back to the container duration if the stream duration is not available
            duration = float(container.duration / av.time_base)
    return duration

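`get_video_duration_in_s` is what lets the converter assign each episode its cumulative `from_timestamp`/`to_timestamp` window inside a concatenated file. A toy sketch of that accumulation (paths illustrative):

```python
cursor = 0.0
for ep_path in ["ep0.mp4", "ep1.mp4"]:
    d = get_video_duration_in_s(ep_path)
    print(ep_path, "spans", (cursor, cursor + d), "in the concatenated file")
    cursor += d
```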
class VideoEncodingManager:
    """
    Context manager that ensures proper video encoding and data cleanup even if exceptions occur.
@@ -487,7 +653,7 @@ class VideoEncodingManager:
                f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
                f"from episode {start_ep} to {end_ep - 1}"
            )
            self.dataset.batch_encode_videos(start_ep, end_ep)
            self.dataset._batch_save_episode_video(start_ep, end_ep)

        # Clean up episode images if recording was interrupted
        if exc_type is not None:
@@ -161,71 +161,33 @@ class XarmEnv(EnvConfig):


@dataclass
class ImagePreprocessingConfig:
    crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
    resize_size: tuple[int, int] | None = None
class VideoRecordConfig:
    """Configuration for video recording in ManiSkill environments."""

    enabled: bool = False
    record_dir: str = "videos"
    trajectory_name: str = "trajectory"


@dataclass
class RewardClassifierConfig:
    """Configuration for reward classification."""

    pretrained_path: str | None = None
    success_threshold: float = 0.5
    success_reward: float = 1.0


@dataclass
class InverseKinematicsConfig:
    """Configuration for inverse kinematics processing."""

    urdf_path: str | None = None
    target_frame_name: str | None = None
    end_effector_bounds: dict[str, list[float]] | None = None
    end_effector_step_sizes: dict[str, float] | None = None


@dataclass
class ObservationConfig:
    """Configuration for observation processing."""
class EnvTransformConfig:
    """Configuration for environment wrappers."""

    # ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
    control_mode: str = "gamepad"
    display_cameras: bool = False
    add_joint_velocity_to_observation: bool = False
    add_current_to_observation: bool = False
    add_ee_pose_to_observation: bool = False
    display_cameras: bool = False


@dataclass
class GripperConfig:
    """Configuration for gripper control and penalties."""

    use_gripper: bool = True
    gripper_penalty: float = 0.0
    gripper_penalty_in_reward: bool = False


@dataclass
class ResetConfig:
    """Configuration for environment reset behavior."""

    crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
    resize_size: tuple[int, int] | None = None
    control_time_s: float = 20.0
    fixed_reset_joint_positions: Any | None = None
    reset_time_s: float = 5.0
    control_time_s: float = 20.0
    terminate_on_success: bool = True


@dataclass
class HILSerlProcessorConfig:
    """Configuration for environment processing pipeline."""

    control_mode: str = "gamepad"
    observation: ObservationConfig | None = None
    image_preprocessing: ImagePreprocessingConfig | None = None
    gripper: GripperConfig | None = None
    reset: ResetConfig | None = None
    inverse_kinematics: InverseKinematicsConfig | None = None
    reward_classifier: RewardClassifierConfig | None = None
    max_gripper_pos: float | None = 100.0
    use_gripper: bool = True
    gripper_quantization_threshold: float | None = 0.8
    gripper_penalty: float = 0.0
    gripper_penalty_in_reward: bool = False

|
||||
@@ -235,10 +197,77 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
||||
|
||||
robot: RobotConfig | None = None
|
||||
teleop: TeleoperatorConfig | None = None
|
||||
processor: HILSerlProcessorConfig = field(default_factory=HILSerlProcessorConfig)
|
||||
|
||||
wrapper: EnvTransformConfig | None = None
|
||||
fps: int = 10
|
||||
name: str = "real_robot"
|
||||
mode: str | None = None # Either "record", "replay", None
|
||||
repo_id: str | None = None
|
||||
dataset_root: str | None = None
|
||||
task: str | None = ""
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: str | None = None
|
||||
reward_classifier_pretrained_path: str | None = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("hil")
|
||||
@dataclass
|
||||
class HILEnvConfig(EnvConfig):
|
||||
"""Configuration for the HIL environment."""
|
||||
|
||||
name: str = "PandaPickCube"
|
||||
task: str | None = "PandaPickCubeKeyboard-v0"
|
||||
use_viewer: bool = True
|
||||
gripper_penalty: float = 0.0
|
||||
use_gamepad: bool = True
|
||||
state_dim: int = 18
|
||||
action_dim: int = 4
|
||||
fps: int = 100
|
||||
episode_length: int = 100
|
||||
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"observation.image": OBS_IMAGE,
|
||||
"observation.state": OBS_STATE,
|
||||
}
|
||||
)
|
||||
################# args from hilserlrobotenv
|
||||
reward_classifier_pretrained_path: str | None = None
|
||||
robot_config: RobotConfig | None = None
|
||||
teleop_config: TeleoperatorConfig | None = None
|
||||
wrapper: EnvTransformConfig | None = None
|
||||
mode: str | None = None # Either "record", "replay", None
|
||||
repo_id: str | None = None
|
||||
dataset_root: str | None = None
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: str | None = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
############################
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"use_viewer": self.use_viewer,
|
||||
"use_gamepad": self.use_gamepad,
|
||||
"gripper_penalty": self.gripper_penalty,
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import importlib

import gymnasium as gym

from lerobot.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv


def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -27,6 +27,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
        return PushtEnv(**kwargs)
    elif env_type == "xarm":
        return XarmEnv(**kwargs)
    elif env_type == "hil":
        return HILEnvConfig(**kwargs)
    else:
        raise ValueError(f"Env type '{env_type}' is not available.")

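A sketch of what the new branch enables: `"hil"` now resolves through the env config factory (keyword values illustrative):

```python
cfg = make_env_config("hil", fps=100, episode_length=100)
assert isinstance(cfg, HILEnvConfig)
```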
@@ -15,17 +15,6 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0.processor_pi0 import Pi0NewLineProcessor
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig

__all__ = [
    "ACTConfig",
    "DiffusionConfig",
    "PI0Config",
    "SmolVLAConfig",
    "TDMPCConfig",
    "VQBeTConfig",
]

@@ -35,6 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d

from lerobot.constants import ACTION, OBS_IMAGES
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy


@@ -50,16 +51,27 @@ class ACTPolicy(PreTrainedPolicy):
    def __init__(
        self,
        config: ACTConfig,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            config: Policy configuration class instance or None, in which case the default instantiation of
                the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.model = ACT(config)

        if config.temporal_ensemble_coeff is not None:
@@ -125,19 +137,23 @@ class ACTPolicy(PreTrainedPolicy):
        """Predict a chunk of actions given environment observations."""
        self.eval()

        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]

        actions = self.model(batch)[0]
        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
        return actions

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]

        batch = self.normalize_targets(batch)
        actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

        l1_loss = (
@@ -1,70 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyProcessorPipeline,
    ProcessorKwargs,
    RenameProcessorStep,
    UnnormalizerProcessorStep,
)


def make_act_pre_post_processors(
    config: ACTConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    preprocessor_kwargs: ProcessorKwargs | None = None,
    postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
    if preprocessor_kwargs is None:
        preprocessor_kwargs = {}
    if postprocessor_kwargs is None:
        postprocessor_kwargs = {}

    input_steps = [
        RenameProcessorStep(rename_map={}),
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        AddBatchDimensionProcessorStep(),
        DeviceProcessorStep(device=config.device),
    ]
    output_steps = [
        DeviceProcessorStep(device="cpu"),
        UnnormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
    ]

    return (
        PolicyProcessorPipeline(
            steps=input_steps,
            name=PREPROCESSOR_DEFAULT_NAME,
            **preprocessor_kwargs,
        ),
        PolicyProcessorPipeline(
            steps=output_steps,
            name=POSTPROCESSOR_DEFAULT_NAME,
            **postprocessor_kwargs,
        ),
    )
@@ -35,6 +35,7 @@ from torch import Tensor, nn

from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import (
    get_device_from_parameters,
@@ -56,6 +57,7 @@ class DiffusionPolicy(PreTrainedPolicy):
    def __init__(
        self,
        config: DiffusionConfig,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
@@ -68,6 +70,14 @@ class DiffusionPolicy(PreTrainedPolicy):
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        # queues are populated during rollout of the policy, they contain the n latest observations and actions
        self._queues = None

@@ -96,6 +106,9 @@ class DiffusionPolicy(PreTrainedPolicy):
        batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
        actions = self.diffusion.generate_actions(batch)

        # TODO(rcadene): make above methods return output dictionary?
        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]

        return actions

    @torch.no_grad()
@@ -124,6 +137,7 @@ class DiffusionPolicy(PreTrainedPolicy):
        if ACTION in batch:
            batch.pop(ACTION)

        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
@@ -139,9 +153,11 @@ class DiffusionPolicy(PreTrainedPolicy):

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
        batch = self.normalize_targets(batch)
        loss = self.diffusion.compute_loss(batch)
        # no output_dict so returning None
        return loss, None

@@ -1,70 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyProcessorPipeline,
    ProcessorKwargs,
    RenameProcessorStep,
    UnnormalizerProcessorStep,
)


def make_diffusion_pre_post_processors(
    config: DiffusionConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    preprocessor_kwargs: ProcessorKwargs | None = None,
    postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
    if preprocessor_kwargs is None:
        preprocessor_kwargs = {}
    if postprocessor_kwargs is None:
        postprocessor_kwargs = {}

    input_steps = [
        RenameProcessorStep(rename_map={}),
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        AddBatchDimensionProcessorStep(),
        DeviceProcessorStep(device=config.device),
    ]
    output_steps = [
        DeviceProcessorStep(device="cpu"),
        UnnormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
    ]
    return (
        PolicyProcessorPipeline(
            steps=input_steps,
            name=PREPROCESSOR_DEFAULT_NAME,
            **preprocessor_kwargs,
        ),
        PolicyProcessorPipeline(
            steps=output_steps,
            name=POSTPROCESSOR_DEFAULT_NAME,
            **postprocessor_kwargs,
        ),
    )
@@ -14,13 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from typing import Any, TypedDict

import torch
from typing_extensions import Unpack
from torch import nn

from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
@@ -38,10 +34,9 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import PolicyProcessorPipeline, ProcessorKwargs


def get_policy_class(name: str) -> type[PreTrainedPolicy]:
def get_policy_class(name: str) -> PreTrainedPolicy:
    """Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
    if name == "tdmpc":
        from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
@@ -106,159 +101,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
    raise ValueError(f"Policy type '{policy_type}' is not available.")


class ProcessorConfigKwargs(TypedDict, total=False):
    """Keyword arguments for the processor config."""

    preprocessor_config_filename: str | None
    postprocessor_config_filename: str | None
    preprocessor_overrides: dict[str, Any] | None
    postprocessor_overrides: dict[str, Any] | None
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None
    preprocessor_kwargs: ProcessorKwargs | None
    postprocessor_kwargs: ProcessorKwargs | None


def make_pre_post_processors(
    policy_cfg: PreTrainedConfig,
    pretrained_path: str | None = None,
    **kwargs: Unpack[ProcessorConfigKwargs],
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
    """Make a processor instance for a given policy type.

    This function creates the appropriate processor configuration based on the policy type.
    Each policy type has its own processor with specific preprocessing steps.

    Args:
        policy_cfg: The config of the policy to create a processor for (e.g., "act", "diffusion", etc.)
        pretrained_path: Optional path to load a pretrained processor from. If provided, loads
            the processor from this path instead of creating a new one.
        **kwargs: Additional keyword arguments passed to the processor creation.

    Returns:
        Tuple of (input_processor, output_processor) for the policy.

    Raises:
        NotImplementedError: If the policy type doesn't have a processor implemented.
    """
    if pretrained_path:
        # Extract preprocessor and postprocessor kwargs
        preprocessor_kwargs = kwargs.get("preprocessor_kwargs", {})
        postprocessor_kwargs = kwargs.get("postprocessor_kwargs", {})

        return (
            PolicyProcessorPipeline.from_pretrained(
                pretrained_model_name_or_path=pretrained_path,
                config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"),
                overrides=kwargs.get("preprocessor_overrides", {}),
                to_transition=preprocessor_kwargs.get("to_transition"),
                to_output=preprocessor_kwargs.get("to_output"),
            ),
            PolicyProcessorPipeline.from_pretrained(
                pretrained_model_name_or_path=pretrained_path,
                config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"),
                overrides=kwargs.get("postprocessor_overrides", {}),
                to_transition=postprocessor_kwargs.get("to_transition"),
                to_output=postprocessor_kwargs.get("to_output"),
            ),
        )

    # Create a new processor based on policy type
    if isinstance(policy_cfg, TDMPCConfig):
        from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors

        processors = make_tdmpc_pre_post_processors(
            config=policy_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
        )

    elif isinstance(policy_cfg, DiffusionConfig):
        from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors

        processors = make_diffusion_pre_post_processors(
            config=policy_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
        )

    elif isinstance(policy_cfg, ACTConfig):
        from lerobot.policies.act.processor_act import make_act_pre_post_processors

        processors = make_act_pre_post_processors(
            config=policy_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
        )

    elif isinstance(policy_cfg, VQBeTConfig):
        from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors

        processors = make_vqbet_pre_post_processors(
            config=policy_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
        )

    elif isinstance(policy_cfg, PI0Config):
        from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors

        processors = make_pi0_pre_post_processors(
            config=policy_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
        )

    elif isinstance(policy_cfg, PI0FASTConfig):
        from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors

        processors = make_pi0fast_pre_post_processors(
            config=policy_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
        )

    elif isinstance(policy_cfg, SACConfig):
        from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors

        processors = make_sac_pre_post_processors(
            config=policy_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
        )

    elif isinstance(policy_cfg, RewardClassifierConfig):
        from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor

        processors = make_classifier_processor(
            config=policy_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
        )

    elif isinstance(policy_cfg, SmolVLAConfig):
        from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors

        processors = make_smolvla_pre_post_processors(
            config=policy_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
        )

    else:
        raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")

    return processors


def make_policy(
    cfg: PreTrainedConfig,
    ds_meta: LeRobotDatasetMetadata | None = None,
@@ -305,6 +147,7 @@ def make_policy(
    kwargs = {}
    if ds_meta is not None:
        features = dataset_to_policy_features(ds_meta.features)
        kwargs["dataset_stats"] = ds_meta.stats
    else:
        if not cfg.pretrained_path:
            logging.warning(
@@ -312,8 +155,6 @@ def make_policy(
                "rather than a dataset. Normalization modules inside the policy will have infinite values "
                "by default without stats from a dataset."
            )
        if env_cfg is None:
            raise ValueError("env_cfg cannot be None when ds_meta is not provided")
        features = env_to_policy_features(env_cfg)

    cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
@@ -330,7 +171,7 @@ def make_policy(
    policy = policy_cls(**kwargs)

    policy.to(cfg.device)
    assert isinstance(policy, torch.nn.Module)
    assert isinstance(policy, nn.Module)

    # policy = torch.compile(policy, mode="reduce-overhead")

@@ -0,0 +1,420 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import Tensor, nn

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature


def create_stats_buffers(
    features: dict[str, PolicyFeature],
    norm_map: dict[str, NormalizationMode],
    stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, dict[str, nn.ParameterDict]]:
    """
    Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
    statistics.

    Args: (see Normalize and Unnormalize)

    Returns:
        dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
        `nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
    """
    stats_buffers = {}

    for key, ft in features.items():
        norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
        if norm_mode is NormalizationMode.IDENTITY:
            continue

        assert isinstance(norm_mode, NormalizationMode)

        shape = tuple(ft.shape)

        if ft.type is FeatureType.VISUAL:
            # sanity checks
            assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=})"
            c, h, w = shape
            assert c < h and c < w, f"{key} is not channel first ({shape=})"
            # override image shape to be invariant to height and width
            shape = (c, 1, 1)

        # Note: we initialize mean, std, min, max to infinity. They should be overwritten
        # downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
        # we assert they are not infinity anymore.

        buffer = {}
        if norm_mode is NormalizationMode.MEAN_STD:
            mean = torch.ones(shape, dtype=torch.float32) * torch.inf
            std = torch.ones(shape, dtype=torch.float32) * torch.inf
            buffer = nn.ParameterDict(
                {
                    "mean": nn.Parameter(mean, requires_grad=False),
                    "std": nn.Parameter(std, requires_grad=False),
                }
            )
        elif norm_mode is NormalizationMode.MIN_MAX:
            min = torch.ones(shape, dtype=torch.float32) * torch.inf
            max = torch.ones(shape, dtype=torch.float32) * torch.inf
            buffer = nn.ParameterDict(
                {
                    "min": nn.Parameter(min, requires_grad=False),
                    "max": nn.Parameter(max, requires_grad=False),
                }
            )

        # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
        if stats:
            if isinstance(stats[key]["mean"], np.ndarray):
                if norm_mode is NormalizationMode.MEAN_STD:
                    buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
                    buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
                elif norm_mode is NormalizationMode.MIN_MAX:
                    buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
                    buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
            elif isinstance(stats[key]["mean"], torch.Tensor):
                # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
                # tensors anywhere (for example, when we use the same stats for normalization and
                # unnormalization). See the logic here
                # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
                if norm_mode is NormalizationMode.MEAN_STD:
                    buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
                    buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
                elif norm_mode is NormalizationMode.MIN_MAX:
                    buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
                    buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
            else:
                type_ = type(stats[key]["mean"])
                raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")

        stats_buffers[key] = buffer
    return stats_buffers

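A sketch of what the factory returns for one visual and one state feature, assuming MEAN_STD normalization for both (feature names illustrative):

```python
features = {
    "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 96, 96)),
    "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(6,)),
}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD, FeatureType.STATE: NormalizationMode.MEAN_STD}
buffers = create_stats_buffers(features, norm_map)
# Image stats are broadcast per channel: shape (3, 1, 1); state stats keep their shape (6,).
assert buffers["observation.image"]["mean"].shape == (3, 1, 1)
assert buffers["observation.state"]["std"].shape == (6,)
```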
def _no_stats_error_str(name: str) -> str:
    return (
        f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
        "pretrained model."
    )


class Normalize(nn.Module):
    """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
                are their shapes (e.g. `[3, 96, 96]`). These shapes are used to create the tensor buffer containing
                mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
                is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
            modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
                are their normalization modes among:
                    - "mean_std": subtract the mean and divide by standard deviation.
                    - "min_max": map to [-1, 1] range.
            stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
                and values are dictionaries of statistic types and their values (e.g.
                `{"mean": torch.randn(3, 1, 1), "std": torch.randn(3, 1, 1)}`). If provided, as expected for
                training the model for the first time, these statistics will overwrite the default buffers. If
                not provided, as expected for finetuning or evaluation, the default buffers should be
                overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
                dataset is not needed to get the stats, since they are already in the policy state_dict.
        """
        super().__init__()
        self.features = features
        self.norm_map = norm_map
        self.stats = stats
        stats_buffers = create_stats_buffers(features, norm_map, stats)
        for key, buffer in stats_buffers.items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    # TODO(rcadene): should we remove torch.no_grad?
    @torch.no_grad()
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        # TODO: Remove this shallow copy
        batch = dict(batch)  # shallow copy avoids mutating the input batch
        for key, ft in self.features.items():
            if key not in batch:
                # FIXME(aliberts, rcadene): This might lead to silent fail!
                continue

            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
            if norm_mode is NormalizationMode.IDENTITY:
                continue

            buffer = getattr(self, "buffer_" + key.replace(".", "_"))

            if norm_mode is NormalizationMode.MEAN_STD:
                mean = buffer["mean"]
                std = buffer["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                batch[key] = (batch[key] - mean) / (std + 1e-8)
            elif norm_mode is NormalizationMode.MIN_MAX:
                min = buffer["min"]
                max = buffer["max"]
                assert not torch.isinf(min).any(), _no_stats_error_str("min")
                assert not torch.isinf(max).any(), _no_stats_error_str("max")
                # normalize to [0,1]
                batch[key] = (batch[key] - min) / (max - min + 1e-8)
                # normalize to [-1, 1]
                batch[key] = batch[key] * 2 - 1
            else:
                raise ValueError(norm_mode)
        return batch

class Unnormalize(nn.Module):
    """
    Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) into the
    original range used by the environment.
    """

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            features (dict): A dictionary where keys are output modalities (e.g. "action") and values
                are `PolicyFeature` objects whose shapes (e.g. `[3, 96, 96]`) are used to create the
                tensor buffers containing mean, std, min, max statistics. If a feature is an image, the
                buffer shape is adjusted to be invariant to height and width, assuming a channel-first
                (c, h, w) format.
            norm_map (dict): A dictionary mapping feature types to their normalization modes among:
                - "mean_std": multiply by the standard deviation and add the mean.
                - "min_max": map from the [-1, 1] range back to the original [min, max] range.
            stats (dict, optional): A dictionary where keys are output modalities (e.g. "action") and
                values are dictionaries of statistic types and their values (e.g.
                `{"mean": torch.randn(3, 1, 1), "std": torch.randn(3, 1, 1)}`). If provided, as expected
                when training the model for the first time, these statistics overwrite the default
                buffers. If not provided, as expected for finetuning or evaluation, the default buffers
                should be overwritten by a call to `policy.load_state_dict(state_dict)`. That way,
                initializing the dataset is not needed to get the stats, since they are already in the
                policy state_dict.
        """
        super().__init__()
        self.features = features
        self.norm_map = norm_map
        self.stats = stats
        # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
        stats_buffers = create_stats_buffers(features, norm_map, stats)
        for key, buffer in stats_buffers.items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    # TODO(rcadene): should we remove torch.no_grad?
    @torch.no_grad()
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        batch = dict(batch)  # shallow copy avoids mutating the input batch
        for key, ft in self.features.items():
            if key not in batch:
                continue

            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
            if norm_mode is NormalizationMode.IDENTITY:
                continue

            buffer = getattr(self, "buffer_" + key.replace(".", "_"))

            if norm_mode is NormalizationMode.MEAN_STD:
                mean = buffer["mean"]
                std = buffer["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                batch[key] = batch[key] * std + mean
            elif norm_mode is NormalizationMode.MIN_MAX:
                min = buffer["min"]
                max = buffer["max"]
                assert not torch.isinf(min).any(), _no_stats_error_str("min")
                assert not torch.isinf(max).any(), _no_stats_error_str("max")
                batch[key] = (batch[key] + 1) / 2
                batch[key] = batch[key] * (max - min) + min
            else:
                raise ValueError(norm_mode)
        return batch
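
# Editor's sketch (not part of the diff): a minimal round trip through the two modules
# above, using made-up stats and a single 2-dimensional "action" feature. `PolicyFeature`,
# `FeatureType` and `NormalizationMode` are the lerobot config types already used in this
# module; the concrete values here are illustrative only.
#
#   features = {"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,))}
#   norm_map = {FeatureType.ACTION: NormalizationMode.MIN_MAX}
#   stats = {"action": {"min": torch.zeros(2), "max": torch.full((2,), 10.0)}}
#   normalize = Normalize(features, norm_map, stats)
#   unnormalize = Unnormalize(features, norm_map, stats)
#   batch = {"action": torch.tensor([[2.5, 7.5]])}
#   out = unnormalize(normalize(batch))  # ≈ batch, up to the 1e-8 epsilon in min_max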

# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
# and remove the `Normalize` and `Unnormalize` classes.
def _initialize_stats_buffers(
    module: nn.Module,
    features: dict[str, PolicyFeature],
    norm_map: dict[str, NormalizationMode],
    stats: dict[str, dict[str, Tensor]] | None = None,
) -> None:
    """Register statistics buffers (mean/std or min/max) on the given *module*.

    The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
    but is factored out so it can be reused by both classes and stay in sync.
    """
    for key, ft in features.items():
        norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
        if norm_mode is NormalizationMode.IDENTITY:
            continue

        shape: tuple[int, ...] = tuple(ft.shape)
        if ft.type is FeatureType.VISUAL:
            # reduce spatial dimensions, keep channel dimension only
            c, *_ = shape
            shape = (c, 1, 1)

        prefix = key.replace(".", "_")

        if norm_mode is NormalizationMode.MEAN_STD:
            mean = torch.full(shape, torch.inf, dtype=torch.float32)
            std = torch.full(shape, torch.inf, dtype=torch.float32)

            if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
                mean_data = stats[key]["mean"]
                std_data = stats[key]["std"]
                if isinstance(mean_data, torch.Tensor):
                    # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
                    # tensors anywhere (for example, when we use the same stats for normalization and
                    # unnormalization). See the logic here
                    # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
                    mean = mean_data.clone().to(dtype=torch.float32)
                    std = std_data.clone().to(dtype=torch.float32)
                else:
                    raise ValueError(f"Unsupported stats type for key '{key}' (expected torch.Tensor).")

            module.register_buffer(f"{prefix}_mean", mean)
            module.register_buffer(f"{prefix}_std", std)
            continue

        if norm_mode is NormalizationMode.MIN_MAX:
            min_val = torch.full(shape, torch.inf, dtype=torch.float32)
            max_val = torch.full(shape, torch.inf, dtype=torch.float32)

            if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
                min_data = stats[key]["min"]
                max_data = stats[key]["max"]
                if isinstance(min_data, torch.Tensor):
                    min_val = min_data.clone().to(dtype=torch.float32)
                    max_val = max_data.clone().to(dtype=torch.float32)
                else:
                    raise ValueError(f"Unsupported stats type for key '{key}' (expected torch.Tensor).")

            module.register_buffer(f"{prefix}_min", min_val)
            module.register_buffer(f"{prefix}_max", max_val)
            continue

        raise ValueError(norm_mode)
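
# Editor's sketch (assumed usage, matching the helper above): registered buffers land in
# the module's state_dict under flattened names, so checkpoints carry the stats around.
#
#   m = nn.Module()
#   feats = {"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(4,))}
#   _initialize_stats_buffers(
#       m, feats, {FeatureType.STATE: NormalizationMode.MEAN_STD},
#       stats={"observation.state": {"mean": torch.zeros(4), "std": torch.ones(4)}},
#   )
#   sorted(m.state_dict())  # ['observation_state_mean', 'observation_state_std']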

class NormalizeBuffer(nn.Module):
    """Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        super().__init__()
        self.features = features
        self.norm_map = norm_map

        _initialize_stats_buffers(self, features, norm_map, stats)

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        batch = dict(batch)
        for key, ft in self.features.items():
            if key not in batch:
                continue

            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
            if norm_mode is NormalizationMode.IDENTITY:
                continue

            prefix = key.replace(".", "_")

            if norm_mode is NormalizationMode.MEAN_STD:
                mean = getattr(self, f"{prefix}_mean")
                std = getattr(self, f"{prefix}_std")
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                batch[key] = (batch[key] - mean) / (std + 1e-8)
                continue

            if norm_mode is NormalizationMode.MIN_MAX:
                min_val = getattr(self, f"{prefix}_min")
                max_val = getattr(self, f"{prefix}_max")
                assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
                assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
                batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
                batch[key] = batch[key] * 2 - 1
                continue

            raise ValueError(norm_mode)

        return batch

class UnnormalizeBuffer(nn.Module):
    """Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        super().__init__()
        self.features = features
        self.norm_map = norm_map

        _initialize_stats_buffers(self, features, norm_map, stats)

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        batch = dict(batch)  # shallow copy avoids mutating the input batch, as in NormalizeBuffer
        for key, ft in self.features.items():
            if key not in batch:
                continue

            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
            if norm_mode is NormalizationMode.IDENTITY:
                continue

            prefix = key.replace(".", "_")

            if norm_mode is NormalizationMode.MEAN_STD:
                mean = getattr(self, f"{prefix}_mean")
                std = getattr(self, f"{prefix}_std")
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                batch[key] = batch[key] * std + mean
                continue

            if norm_mode is NormalizationMode.MIN_MAX:
                min_val = getattr(self, f"{prefix}_min")
                max_val = getattr(self, f"{prefix}_max")
                assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
                assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
                batch[key] = (batch[key] + 1) / 2
                batch[key] = batch[key] * (max_val - min_val) + min_val
                continue

            raise ValueError(norm_mode)

        return batch
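
# Editor's sketch of the finetuning/evaluation path implied by the docstrings above:
# construct the buffer modules without stats, then let a checkpoint fill them in.
# `features`, `norm_map` and `ckpt_norm_state_dict` are illustrative placeholders.
#
#   norm = NormalizeBuffer(features, norm_map, stats=None)  # buffers start at torch.inf
#   norm.load_state_dict(ckpt_norm_state_dict)              # hypothetical checkpoint slice
#   # the isinf asserts in forward() then pass, since real stats replaced the placeholders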

@@ -56,15 +56,18 @@ from collections import deque
import torch
import torch.nn.functional as F  # noqa: N812
from torch import Tensor, nn
from transformers import AutoTokenizer

-from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
+from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.paligemma_with_expert import (
    PaliGemmaWithExpertConfig,
    PaliGemmaWithExpertModel,
)
from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.utils.utils import get_safe_dtype
+from lerobot.policies.utils import log_model_loading_keys
+from lerobot.utils.utils import get_safe_dtype, init_logging


def create_sinusoidal_pos_embedding(
@@ -220,17 +223,28 @@ class PI0Policy(PreTrainedPolicy):
    def __init__(
        self,
        config: PI0Config,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            config: Policy configuration class instance or None, in which case the default instantiation of
                the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """

        super().__init__(config)
        config.validate_features()
        self.config = config
        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
        self.model = PI0FlowMatching(config)

        self.reset()
@@ -239,6 +253,99 @@ class PI0Policy(PreTrainedPolicy):
        """This should be called whenever the environment is reset."""
        self._action_queue = deque([], maxlen=self.config.n_action_steps)

    @classmethod
    def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
        """
        Transform state dict keys to match the expected model structure.

        Transformations:
        - model.paligemma_with_expert.paligemma.language_model.lm_head ->
          model.paligemma_with_expert.paligemma.lm_head
        - model.paligemma_with_expert.paligemma.language_model.model ->
          model.paligemma_with_expert.paligemma.model.language_model
        - model.paligemma_with_expert.paligemma.vision_tower ->
          model.paligemma_with_expert.paligemma.model.vision_tower
        - model.paligemma_with_expert.paligemma.multi_modal_projector ->
          model.paligemma_with_expert.paligemma.model.multi_modal_projector

        Also handles tied weights between lm_head.weight and embed_tokens.weight.
        """
        import re

        transformed_dict = {}

        transformations = [
            (
                re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
                ".paligemma_with_expert.paligemma.lm_head",
            ),
            (
                re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
                ".paligemma_with_expert.paligemma.model.language_model",
            ),
            (
                re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
                ".paligemma_with_expert.paligemma.model.vision_tower",
            ),
            (
                re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
                ".paligemma_with_expert.paligemma.model.multi_modal_projector",
            ),
        ]

        for key, value in state_dict.items():
            new_key = key
            for pattern, replacement in transformations:
                new_key = pattern.sub(replacement, new_key)
            transformed_dict[new_key] = value

        # Handle tied weights: lm_head.weight and embed_tokens.weight share memory
        lm_head_key = None
        embed_tokens_key = None

        for key in transformed_dict:
            if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
                lm_head_key = key
            elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
                embed_tokens_key = key
            if lm_head_key and embed_tokens_key:
                break

        if lm_head_key and not embed_tokens_key:
            embed_tokens_key = lm_head_key.replace(
                ".lm_head.weight", ".model.language_model.embed_tokens.weight"
            )
            transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key]
        elif embed_tokens_key and not lm_head_key:
            lm_head_key = embed_tokens_key.replace(
                ".model.language_model.embed_tokens.weight", ".lm_head.weight"
            )
            transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key]

        return transformed_dict
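
    # Editor's note: a worked example of the rewrites above on a single (illustrative)
    # checkpoint entry — "language_model.model" is flipped to "model.language_model" to
    # match the restructured transformers PaliGemma module tree:
    #
    #   old: "model.paligemma_with_expert.paligemma.language_model.model.layers.0.q_proj.weight"
    #   new: "model.paligemma_with_expert.paligemma.model.language_model.layers.0.q_proj.weight"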

    @classmethod
    def _load_as_safetensor(
        cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
    ) -> "PI0Policy":
        """Override to apply key transformations before loading."""
        from safetensors.torch import load_file

        init_logging()
        # Load the state dict from the file safely
        state_dict = load_file(model_file, device=map_location)

        # Apply key transformations
        transformed_state_dict = cls._transform_state_dict_keys(state_dict)

        # Load the transformed state dict
        msg = model.load_state_dict(transformed_state_dict, strict=strict)

        # Log missing and unexpected keys
        log_model_loading_keys(msg.missing_keys, msg.unexpected_keys)
        return model

    def get_optim_params(self) -> dict:
        return self.parameters()

@@ -270,13 +377,14 @@ class PI0Policy(PreTrainedPolicy):
        if self.config.adapt_to_pi_aloha:
            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])

        batch = self.normalize_inputs(batch)

        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
        # querying the policy.
        if len(self._action_queue) == 0:
            images, img_masks = self.prepare_images(batch)
            state = self.prepare_state(batch)
-            lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
-            lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
+            lang_tokens, lang_masks = self.prepare_language(batch)

            actions = self.model.sample_actions(
                images, img_masks, lang_tokens, lang_masks, state, noise=noise
@@ -286,6 +394,8 @@ class PI0Policy(PreTrainedPolicy):
            original_action_dim = self.config.action_feature.shape[0]
            actions = actions[:, :, :original_action_dim]

            actions = self.unnormalize_outputs({"action": actions})["action"]

            if self.config.adapt_to_pi_aloha:
                actions = self._pi_aloha_encode_actions(actions)

@@ -300,10 +410,12 @@ class PI0Policy(PreTrainedPolicy):
            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
            batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])

        batch = self.normalize_inputs(batch)
        batch = self.normalize_targets(batch)

        images, img_masks = self.prepare_images(batch)
        state = self.prepare_state(batch)
-        lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
-        lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
+        lang_tokens, lang_masks = self.prepare_language(batch)
        actions = self.prepare_action(batch)
        actions_is_pad = batch.get("action_is_pad")

@@ -370,6 +482,26 @@ class PI0Policy(PreTrainedPolicy):

        return images, img_masks

    def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
        """Tokenize the text input"""
        device = batch[OBS_STATE].device
        tasks = batch["task"]

        # PaliGemma prompt has to end with a new line
        tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]

        tokenized_prompt = self.language_tokenizer(
            tasks,
            padding="max_length",
            padding_side="right",
            max_length=self.config.tokenizer_max_length,
            return_tensors="pt",
        )
        lang_tokens = tokenized_prompt["input_ids"].to(device=device)
        lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)

        return lang_tokens, lang_masks
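
    # Editor's sketch of the contract above (shapes illustrative): a task string such as
    # "pick up the cube" becomes "pick up the cube\n" — PaliGemma expects the prompt to
    # end with a newline — and is then padded to `tokenizer_max_length`, giving
    #
    #   lang_tokens: (batch_size, tokenizer_max_length) int64 token ids
    #   lang_masks:  (batch_size, tokenizer_max_length) bool, False on the padding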

    def _pi_aloha_decode_state(self, state):
        # Flip the joints.
        for motor_idx in [1, 2, 8, 9]:
@@ -435,7 +567,7 @@ class PI0FlowMatching(nn.Module):
    └──────────────────────────────┘
    """

-    def __init__(self, config: PI0Config):
+    def __init__(self, config):
        super().__init__()
        self.config = config

@@ -1,118 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch

from lerobot.configs.types import PolicyFeature
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    ComplementaryDataProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyProcessorPipeline,
    ProcessorKwargs,
    ProcessorStep,
    ProcessorStepRegistry,
    RenameProcessorStep,
    TokenizerProcessorStep,
    UnnormalizerProcessorStep,
)


@ProcessorStepRegistry.register(name="pi0_new_line_processor")
class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
    """Add a newline to the end of the task if it doesn't have one.
    This is required for the PaliGemma tokenizer.
    """

    def complementary_data(self, complementary_data):
        if "task" not in complementary_data:
            return complementary_data

        task = complementary_data["task"]
        if task is None:
            return complementary_data

        new_complementary_data = dict(complementary_data)

        # Handle both string and list of strings
        if isinstance(task, str):
            # Single string: add newline if not present
            if not task.endswith("\n"):
                new_complementary_data["task"] = f"{task}\n"
        elif isinstance(task, list) and all(isinstance(t, str) for t in task):
            # List of strings: add newline to each if not present
            new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
        # If task is neither string nor list of strings, leave unchanged

        return new_complementary_data

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features


def make_pi0_pre_post_processors(
    config: PI0Config,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    preprocessor_kwargs: ProcessorKwargs | None = None,
    postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
    if preprocessor_kwargs is None:
        preprocessor_kwargs = {}
    if postprocessor_kwargs is None:
        postprocessor_kwargs = {}

    # Add remaining processors
    input_steps: list[ProcessorStep] = [
        RenameProcessorStep(rename_map={}),  # To mimic the same processor as pretrained one
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        AddBatchDimensionProcessorStep(),
        Pi0NewLineProcessor(),  # Add newlines before tokenization for PaliGemma
        TokenizerProcessorStep(
            tokenizer_name="google/paligemma-3b-pt-224",
            max_length=config.tokenizer_max_length,
            padding_side="right",
            padding="max_length",
        ),
        DeviceProcessorStep(device=config.device),
    ]

    output_steps: list[ProcessorStep] = [
        DeviceProcessorStep(device="cpu"),
        UnnormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
    ]

    return (
        PolicyProcessorPipeline(
            steps=input_steps,
            name=PREPROCESSOR_DEFAULT_NAME,
            **preprocessor_kwargs,
        ),
        PolicyProcessorPipeline(
            steps=output_steps,
            name=POSTPROCESSOR_DEFAULT_NAME,
            **postprocessor_kwargs,
        ),
    )

@@ -58,6 +58,7 @@ from transformers.cache_utils import HybridCache, StaticCache
from transformers.models.auto import CONFIG_MAPPING

from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pretrained import PreTrainedPolicy

@@ -145,6 +146,14 @@ class PI0FASTPolicy(PreTrainedPolicy):
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
        self.model = PI0FAST(config)

@@ -212,6 +221,8 @@ class PI0FASTPolicy(PreTrainedPolicy):
        if self.config.adapt_to_pi_aloha:
            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])

        batch = self.normalize_inputs(batch)

        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
        # querying the policy.
        if len(self._action_queue) == 0:
@@ -224,6 +235,8 @@ class PI0FASTPolicy(PreTrainedPolicy):
            ]  # self.config.max_action_dim  # self.config.action_feature.shape[0]
            actions = actions[:, :, :original_action_dim]

            actions = self.unnormalize_outputs({"action": actions})["action"]

            if self.config.adapt_to_pi_aloha:
                actions = self._pi_aloha_encode_actions(actions)

@@ -236,6 +249,8 @@ class PI0FASTPolicy(PreTrainedPolicy):
        if self.config.adapt_to_pi_aloha:
            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
            batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
        batch = self.normalize_inputs(batch)
        batch = self.normalize_targets(batch)
        loss_dict = self.model.forward(batch)
        return loss_dict["loss"], loss_dict

@@ -1,70 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyProcessorPipeline,
    ProcessorKwargs,
    RenameProcessorStep,
    UnnormalizerProcessorStep,
)


def make_pi0fast_pre_post_processors(
    config: PI0Config,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    preprocessor_kwargs: ProcessorKwargs | None = None,
    postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
    if preprocessor_kwargs is None:
        preprocessor_kwargs = {}
    if postprocessor_kwargs is None:
        postprocessor_kwargs = {}

    input_steps = [
        RenameProcessorStep(rename_map={}),  # To mimic the same processor as pretrained one
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        AddBatchDimensionProcessorStep(),
        DeviceProcessorStep(device=config.device),
    ]
    output_steps = [
        DeviceProcessorStep(device="cpu"),
        UnnormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
    ]
    return (
        PolicyProcessorPipeline(
            steps=input_steps,
            name=PREPROCESSOR_DEFAULT_NAME,
            **preprocessor_kwargs,
        ),
        PolicyProcessorPipeline(
            steps=output_steps,
            name=POSTPROCESSOR_DEFAULT_NAME,
            **postprocessor_kwargs,
        ),
    )
@@ -28,6 +28,7 @@ import torch.nn.functional as F  # noqa: N812
from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution

from lerobot.policies.normalize import NormalizeBuffer
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
from lerobot.policies.utils import get_device_from_parameters
@@ -44,6 +45,7 @@ class SACPolicy(
    def __init__(
        self,
        config: SACConfig | None = None,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        super().__init__(config)
        config.validate_features()
@@ -51,6 +53,7 @@ class SACPolicy(

        # Determine action dimension and initialize all components
        continuous_action_dim = config.output_features["action"].shape[0]
        self._init_normalization(dataset_stats)
        self._init_encoders()
        self._init_critics(continuous_action_dim)
        self._init_actor(continuous_action_dim)
@@ -85,7 +88,8 @@ class SACPolicy(

        observations_features = None
        if self.shared_encoder and self.actor.encoder.has_images:
-            observations_features = self.actor.encoder.get_cached_image_features(batch)
+            # Cache and normalize image features
+            observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)

        actions, _, _ = self.actor(batch, observations_features)

@@ -387,12 +391,28 @@ class SACPolicy(
        actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
        return actor_loss

    def _init_normalization(self, dataset_stats):
        """Initialize input/output normalization modules."""
        self.normalize_inputs = nn.Identity()
        self.normalize_targets = nn.Identity()
        if self.config.dataset_stats is not None:
            params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
            self.normalize_inputs = NormalizeBuffer(
                self.config.input_features, self.config.normalization_mapping, params
            )
            stats = dataset_stats or params
            self.normalize_targets = NormalizeBuffer(
                self.config.output_features, self.config.normalization_mapping, stats
            )

    def _init_encoders(self):
        """Initialize shared or separate encoders for actor and critic."""
        self.shared_encoder = self.config.shared_encoder
-        self.encoder_critic = SACObservationEncoder(self.config)
+        self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs)
        self.encoder_actor = (
-            self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
+            self.encoder_critic
+            if self.shared_encoder
+            else SACObservationEncoder(self.config, self.normalize_inputs)
        )

    def _init_critics(self, continuous_action_dim):
@@ -404,7 +424,9 @@ class SACPolicy(
            )
            for _ in range(self.config.num_critics)
        ]
-        self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
+        self.critic_ensemble = CriticEnsemble(
+            encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets
+        )
        target_heads = [
            CriticHead(
                input_dim=self.encoder_critic.output_dim + continuous_action_dim,
@@ -412,7 +434,9 @@ class SACPolicy(
            )
            for _ in range(self.config.num_critics)
        ]
-        self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
+        self.critic_target = CriticEnsemble(
+            encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets
+        )
        self.critic_target.load_state_dict(self.critic_ensemble.state_dict())

        if self.config.use_torch_compile:
@@ -466,9 +490,10 @@ class SACPolicy(
class SACObservationEncoder(nn.Module):
    """Encode image and/or state vector observations."""

-    def __init__(self, config: SACConfig) -> None:
+    def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
        super().__init__()
        self.config = config
        self.input_normalization = input_normalizer
        self._init_image_layers()
        self._init_state_layers()
        self._compute_output_dim()
@@ -543,10 +568,11 @@ class SACObservationEncoder(nn.Module):
    def forward(
        self, obs: dict[str, Tensor], cache: dict[str, Tensor] | None = None, detach: bool = False
    ) -> Tensor:
        obs = self.input_normalization(obs)
        parts = []
        if self.has_images:
            if cache is None:
-                cache = self.get_cached_image_features(obs)
+                cache = self.get_cached_image_features(obs, normalize=False)
            parts.append(self._encode_images(cache, detach))
        if self.has_env:
            parts.append(self.env_encoder(obs["observation.environment_state"]))
@@ -559,7 +585,7 @@ class SACObservationEncoder(nn.Module):
                "No parts to concatenate, you should have at least one image or environment state or state"
            )

-    def get_cached_image_features(self, obs: dict[str, Tensor]) -> dict[str, Tensor]:
+    def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
        """Extract and optionally cache image features from observations.

        This function processes image observations through the vision encoder once and returns
@@ -571,17 +597,26 @@ class SACObservationEncoder(nn.Module):
        - The vision encoder forward pass is typically the main computational bottleneck during training and inference
        - Caching these features can provide a 2-4x speedup in training and inference

        Normalization behavior:
        - When called from inside forward(): set normalize=False, since inputs are already normalized
        - When called from outside forward(): set normalize=True to ensure proper input normalization

        Usage patterns:
-        - Called in select_action()
+        - Called in select_action() with normalize=True
        - Called in learner.py's get_observation_features() to pre-compute features for all policy components
-        - Called internally by forward()
+        - Called internally by forward() with normalize=False

        Args:
            obs: Dictionary of observation tensors containing image keys
            normalize: Whether to normalize observations before encoding.
                Set to True when calling directly from outside the encoder's forward method.
                Set to False when calling from within forward(), where inputs are already normalized.

        Returns:
            Dictionary mapping image keys to their corresponding encoded features
        """
        if normalize:
            obs = self.input_normalization(obs)
        batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
        out = self.image_encoder(batched)
        chunks = torch.chunk(out, len(self.image_keys), dim=0)
@@ -712,6 +747,7 @@ class CriticEnsemble(nn.Module):
    Args:
        encoder (SACObservationEncoder): encoder for observations.
        ensemble (List[CriticHead]): list of critic heads.
        output_normalization (nn.Module): normalization layer for actions.
        init_final (float | None): optional initializer scale for final layers.

    Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
@@ -721,11 +757,13 @@ class CriticEnsemble(nn.Module):
        self,
        encoder: SACObservationEncoder,
        ensemble: list[CriticHead],
        output_normalization: nn.Module,
        init_final: float | None = None,
    ):
        super().__init__()
        self.encoder = encoder
        self.init_final = init_final
        self.output_normalization = output_normalization
        self.critics = nn.ModuleList(ensemble)

    def forward(
@@ -737,6 +775,11 @@ class CriticEnsemble(nn.Module):
        device = get_device_from_parameters(self)
        # Move each tensor in observations to device
        observations = {k: v.to(device) for k, v in observations.items()}
        # NOTE: We normalize the actions; it helps sample efficiency
        actions: dict[str, Tensor] = {"action": actions}
        # NOTE: The normalization layer takes a dict as input and returns a dict, which is
        # why the actions are wrapped here and unwrapped on the next line
        actions = self.output_normalization(actions)["action"]
        actions = actions.to(device)

        obs_enc = self.encoder(observations, cache=observation_features)

@@ -1,71 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyProcessorPipeline,
    ProcessorKwargs,
    RenameProcessorStep,
    UnnormalizerProcessorStep,
)


def make_sac_pre_post_processors(
    config: SACConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    preprocessor_kwargs: ProcessorKwargs | None = None,
    postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
    if preprocessor_kwargs is None:
        preprocessor_kwargs = {}
    if postprocessor_kwargs is None:
        postprocessor_kwargs = {}

    input_steps = [
        RenameProcessorStep(rename_map={}),
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        AddBatchDimensionProcessorStep(),
        DeviceProcessorStep(device=config.device),
    ]
    output_steps = [
        DeviceProcessorStep(device="cpu"),
        UnnormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
    ]
    return (
        PolicyProcessorPipeline(
            steps=input_steps,
            name=PREPROCESSOR_DEFAULT_NAME,
            **preprocessor_kwargs,
        ),
        PolicyProcessorPipeline(
            steps=output_steps,
            name=POSTPROCESSOR_DEFAULT_NAME,
            **postprocessor_kwargs,
        ),
    )
@@ -20,6 +20,7 @@ import torch
from torch import Tensor, nn

from lerobot.constants import OBS_IMAGE, REWARD
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig

@@ -107,12 +108,22 @@ class Classifier(PreTrainedPolicy):
    def __init__(
        self,
        config: RewardClassifierConfig,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        from transformers import AutoModel

        super().__init__(config)
        self.config = config

        # Initialize normalization (standardized with the policy framework)
        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        # Set up encoder
        encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
        # Extract vision model if we're given a multimodal model
@@ -236,6 +247,10 @@ class Classifier(PreTrainedPolicy):

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
        """Standard forward pass for training compatible with train.py."""
        # Normalize inputs if needed
        batch = self.normalize_inputs(batch)
        batch = self.normalize_targets(batch)

        # Extract images and labels
        images, labels = self.extract_images_and_labels(batch)


@@ -1,61 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.processor import (
    DeviceProcessorStep,
    IdentityProcessorStep,
    NormalizerProcessorStep,
    PolicyProcessorPipeline,
    ProcessorKwargs,
)


def make_classifier_processor(
    config: RewardClassifierConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    preprocessor_kwargs: ProcessorKwargs | None = None,
    postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
    if preprocessor_kwargs is None:
        preprocessor_kwargs = {}
    if postprocessor_kwargs is None:
        postprocessor_kwargs = {}

    input_steps = [
        NormalizerProcessorStep(
            features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
        NormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
        DeviceProcessorStep(device=config.device),
    ]
    output_steps = [DeviceProcessorStep(device="cpu"), IdentityProcessorStep()]

    return (
        PolicyProcessorPipeline(
            steps=input_steps,
            name="classifier_preprocessor",
            **preprocessor_kwargs,
        ),
        PolicyProcessorPipeline(
            steps=output_steps,
            name="classifier_postprocessor",
            **postprocessor_kwargs,
        ),
    )
@@ -53,13 +53,21 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
"""

import math
import os
import re
from collections import deque

import safetensors
import torch
import torch.nn.functional as F  # noqa: N812
from torch import Tensor, nn
from transformers import AutoProcessor

-from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
+from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import (
    Normalize,
    Unnormalize,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
@@ -68,6 +76,102 @@ from lerobot.policies.utils import (
)
from lerobot.utils.utils import get_safe_dtype

# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")


def canonicalise(k: str) -> str:
    """
    Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
    normalisation-buffer key.
    """
    return _VARIANT_RE.sub(".buffer_", k)
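
# Editor's example (the buffer key is illustrative) of the canonicalisation above:
#
#   canonicalise("normalize_inputs.so100-blue_buffer_observation_state.mean")
#   # -> "normalize_inputs.buffer_observation_state.mean"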

def standardise_state_dict(
    checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
) -> tuple[dict[str, torch.Tensor], list[str]]:
    """
    • Re-keys `checkpoint` so that every entry matches the *reference* key set.
    • If several variant keys collapse to the same canonical name, we keep the
      first one and log the collision.
    • Returns the new dict + a list of entries that could not be matched.
    """
    out, collisions, unmatched = {}, {}, []

    for k, v in checkpoint.items():
        canon = canonicalise(k)
        if canon in ref_keys:
            if canon in out:  # duplicate after collapsing
                collisions.setdefault(canon, []).append(k)
            else:
                out[canon] = v
        else:
            unmatched.append(k)

    if verbose:
        for canon, variants in collisions.items():
            print(f"[standardise_state_dict] '{canon}' ← {variants}")
        if unmatched:
            print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")

    out.update({k: checkpoint[k] for k in unmatched})
    return out, unmatched

def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
    """
    Renames keys in a checkpoint dictionary based on the given rename string.

    Args:
        checkpoint (dict): The checkpoint dictionary.
        rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".

    Returns:
        dict: The modified checkpoint with renamed keys.
    """

    rename_dict = dict(pair.split("//") for pair in rename_str.split(","))

    new_checkpoint = {}
    for k, v in checkpoint.items():
        for old_key, new_key in rename_dict.items():
            if old_key in k:
                k = k.replace(old_key, new_key)
        new_checkpoint[k] = v
    return new_checkpoint
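
# Editor's example: the mapping string "model._orig_mod.//model." used by `load_smolvla`
# below strips the torch.compile wrapper prefix (the tensor key is illustrative):
#
#   rename_checkpoint_keys({"model._orig_mod.vlm.weight": w}, "model._orig_mod.//model.")
#   # -> {"model.vlm.weight": w}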

def load_smolvla(
    model: torch.nn.Module,
    filename: str | os.PathLike,
    *,
    device: str = "cpu",
    checkpoint_keys_mapping: str = "",
) -> torch.nn.Module:
    state_dict = safetensors.torch.load_file(filename, device=device)

    # Optional user-supplied renames (e.g. "model._orig_mod.//model.")
    if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
        state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)

    state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))

    # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
    norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
    state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}

    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    if not all(key.startswith(norm_keys) for key in missing) or unexpected:
        raise RuntimeError(
            f"SmolVLA: {len(missing)} missing / {len(unexpected)} unexpected keys"
        )

    return model


def create_sinusoidal_pos_embedding(
    time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
@@ -222,17 +326,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
    def __init__(
        self,
        config: SmolVLAConfig,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            config: Policy configuration class instance or None, in which case the default instantiation of
                the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """

        super().__init__(config)
        config.validate_features()
        self.config = config
        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
        self.model = VLAFlowMatching(config)
        self.reset()

@@ -242,6 +357,23 @@ class SmolVLAPolicy(PreTrainedPolicy):
            ACTION: deque(maxlen=self.config.n_action_steps),
        }

    # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
    @classmethod
    def _load_as_safetensor(
        cls,
        model: "SmolVLAPolicy",
        model_file: str,
        map_location: str,
        strict: bool,
    ):
        safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
        return load_smolvla(
            model,
            model_file,
            device=map_location,
            checkpoint_keys_mapping="model._orig_mod.//model.",
        )

    def get_optim_params(self) -> dict:
        return self.parameters()

@@ -257,8 +389,7 @@ class SmolVLAPolicy(PreTrainedPolicy):

        images, img_masks = self.prepare_images(batch)
        state = self.prepare_state(batch)
-        lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
-        lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
+        lang_tokens, lang_masks = self.prepare_language(batch)

        actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)

@@ -266,6 +397,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
        original_action_dim = self.config.action_feature.shape[0]
        actions = actions[:, :, :original_action_dim]

        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]

        if self.config.adapt_to_pi_aloha:
            actions = self._pi_aloha_encode_actions(actions)

@@ -275,6 +408,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
        if self.config.adapt_to_pi_aloha:
            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])

        batch = self.normalize_inputs(batch)

        return batch

    @torch.no_grad()
@@ -315,11 +450,11 @@ class SmolVLAPolicy(PreTrainedPolicy):
        if self.config.adapt_to_pi_aloha:
            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
            batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])

        batch = self.normalize_inputs(batch)
        batch = self.normalize_targets(batch)
        images, img_masks = self.prepare_images(batch)
        state = self.prepare_state(batch)
-        lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
-        lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
+        lang_tokens, lang_masks = self.prepare_language(batch)
        actions = self.prepare_action(batch)
        actions_is_pad = batch.get("action_is_pad")
        loss_dict = {}
@@ -383,6 +518,30 @@ class SmolVLAPolicy(PreTrainedPolicy):
            img_masks.append(mask)
        return images, img_masks

    def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
        """Tokenize the text input"""
        device = batch[OBS_STATE].device
        tasks = batch["task"]
        if isinstance(tasks, str):
            tasks = [tasks]

        if len(tasks) == 1:
            tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]

        tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]

        tokenized_prompt = self.language_tokenizer(
            tasks,
            padding=self.config.pad_language_to,
            padding_side="right",
            max_length=self.config.tokenizer_max_length,
            return_tensors="pt",
        )
        lang_tokens = tokenized_prompt["input_ids"].to(device=device)
        lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)

        return lang_tokens, lang_masks

    def _pi_aloha_decode_state(self, state):
        # Flip the joints.
        for motor_idx in [1, 2, 8, 9]:

@@ -1,111 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from lerobot.configs.types import PolicyFeature
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    ComplementaryDataProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyProcessorPipeline,
    ProcessorKwargs,
    ProcessorStepRegistry,
    RenameProcessorStep,
    TokenizerProcessorStep,
    UnnormalizerProcessorStep,
)


def make_smolvla_pre_post_processors(
    config: SmolVLAConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    preprocessor_kwargs: ProcessorKwargs | None = None,
    postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
    if preprocessor_kwargs is None:
        preprocessor_kwargs = {}
    if postprocessor_kwargs is None:
        postprocessor_kwargs = {}

    input_steps = [
        RenameProcessorStep(rename_map={}),  # To mimic the same processor as pretrained one
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        AddBatchDimensionProcessorStep(),
        SmolVLANewLineProcessor(),
        TokenizerProcessorStep(
            tokenizer_name=config.vlm_model_name,
            padding=config.pad_language_to,
            padding_side="right",
            max_length=config.tokenizer_max_length,
        ),
        DeviceProcessorStep(device=config.device),
    ]
    output_steps = [
        DeviceProcessorStep(device="cpu"),
        UnnormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
    ]
    return (
        PolicyProcessorPipeline(
            steps=input_steps,
            name=PREPROCESSOR_DEFAULT_NAME,
            **preprocessor_kwargs,
        ),
        PolicyProcessorPipeline(
            steps=output_steps,
            name=POSTPROCESSOR_DEFAULT_NAME,
            **postprocessor_kwargs,
        ),
    )


@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
class SmolVLANewLineProcessor(ComplementaryDataProcessorStep):
    """Add a newline to the end of the task if it doesn't have one."""

    def complementary_data(self, complementary_data):
        if "task" not in complementary_data:
            return complementary_data

        task = complementary_data["task"]
        if task is None:
            return complementary_data

        new_complementary_data = dict(complementary_data)

        # Handle both string and list of strings
        if isinstance(task, str):
            # Single string: add newline if not present
            if not task.endswith("\n"):
                new_complementary_data["task"] = f"{task}\n"
        elif isinstance(task, list) and all(isinstance(t, str) for t in task):
            # List of strings: add newline to each if not present
            new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
        # If task is neither string nor list of strings, leave unchanged

        return new_complementary_data

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
|
||||
@@ -36,6 +36,7 @@ import torch.nn.functional as F  # noqa: N812
from torch import Tensor

from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
@@ -62,19 +63,26 @@ class TDMPCPolicy(PreTrainedPolicy):
    config_class = TDMPCConfig
    name = "tdmpc"

    def __init__(
        self,
        config: TDMPCConfig,
    ):
    def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
        """
        Args:
            config: Policy configuration class instance or None, in which case the default instantiation of
                the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.model = TDMPCTOLD(config)
        self.model_target = deepcopy(self.model)
        for param in self.model_target.parameters():
@@ -129,6 +137,7 @@ class TDMPCPolicy(PreTrainedPolicy):

        actions = torch.clamp(actions, -1, +1)

        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
        return actions

    @torch.no_grad()
@@ -138,12 +147,11 @@ class TDMPCPolicy(PreTrainedPolicy):
        if ACTION in batch:
            batch.pop(ACTION)

        batch = self.normalize_inputs(batch)

        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
        # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
        if ACTION in batch:
            batch.pop(ACTION)

        self._queues = populate_queues(self._queues, batch)

@@ -312,9 +320,11 @@ class TDMPCPolicy(PreTrainedPolicy):
        """
        device = get_device_from_parameters(self)

        batch = self.normalize_inputs(batch)
        if self.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
        batch = self.normalize_targets(batch)

        info = {}


@@ -1,70 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyProcessorPipeline,
    ProcessorKwargs,
    RenameProcessorStep,
    UnnormalizerProcessorStep,
)


def make_tdmpc_pre_post_processors(
    config: TDMPCConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    preprocessor_kwargs: ProcessorKwargs | None = None,
    postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
    if preprocessor_kwargs is None:
        preprocessor_kwargs = {}
    if postprocessor_kwargs is None:
        postprocessor_kwargs = {}

    input_steps = [
        RenameProcessorStep(rename_map={}),
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        AddBatchDimensionProcessorStep(),
        DeviceProcessorStep(device=config.device),
    ]
    output_steps = [
        DeviceProcessorStep(device="cpu"),
        UnnormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
    ]
    return (
        PolicyProcessorPipeline(
            steps=input_steps,
            name=PREPROCESSOR_DEFAULT_NAME,
            **preprocessor_kwargs,
        ),
        PolicyProcessorPipeline(
            steps=output_steps,
            name=POSTPROCESSOR_DEFAULT_NAME,
            **postprocessor_kwargs,
        ),
    )
@@ -28,6 +28,7 @@ import torchvision
from torch import Tensor, nn

from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
@@ -47,6 +48,7 @@ class VQBeTPolicy(PreTrainedPolicy):
    def __init__(
        self,
        config: VQBeTConfig | None = None,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
@@ -59,6 +61,14 @@ class VQBeTPolicy(PreTrainedPolicy):
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        self.vqbet = VQBeTModel(config)

        self.reset()
@@ -118,6 +128,7 @@ class VQBeTPolicy(PreTrainedPolicy):
    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
        batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
        actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
        return actions

    @torch.no_grad()
@@ -131,12 +142,10 @@ class VQBeTPolicy(PreTrainedPolicy):
        # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
        if ACTION in batch:
            batch.pop(ACTION)
        batch = self.normalize_inputs(batch)
        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
        # NOTE: It's important that this happens after stacking the images into a single key.
        batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
        # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
        if ACTION in batch:
            batch.pop(ACTION)

        self._queues = populate_queues(self._queues, batch)

@@ -156,8 +165,10 @@ class VQBeTPolicy(PreTrainedPolicy):

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
        batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
        batch = self.normalize_targets(batch)
        # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181)
        if not self.vqbet.action_head.vqvae_model.discretized.item():
            # loss: total loss of training RVQ

@@ -1,71 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyProcessorPipeline,
    ProcessorKwargs,
    RenameProcessorStep,
    UnnormalizerProcessorStep,
)


def make_vqbet_pre_post_processors(
    config: VQBeTConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    preprocessor_kwargs: ProcessorKwargs | None = None,
    postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
    if preprocessor_kwargs is None:
        preprocessor_kwargs = {}
    if postprocessor_kwargs is None:
        postprocessor_kwargs = {}

    input_steps = [
        RenameProcessorStep(rename_map={}),  # Give the user the option to rename keys
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        AddBatchDimensionProcessorStep(),
        DeviceProcessorStep(device=config.device),
    ]
    output_steps = [
        DeviceProcessorStep(device="cpu"),
        UnnormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
    ]
    return (
        PolicyProcessorPipeline(
            steps=input_steps,
            name=PREPROCESSOR_DEFAULT_NAME,
            **preprocessor_kwargs,
        ),
        PolicyProcessorPipeline(
            steps=output_steps,
            name=POSTPROCESSOR_DEFAULT_NAME,
            **postprocessor_kwargs,
        ),
    )
@@ -14,90 +14,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .batch_processor import AddBatchDimensionProcessorStep
from .converters import (
    batch_to_transition,
    create_transition,
    merge_transitions,
    transition_to_batch,
    transition_to_dataset_frame,
)
from .core import EnvTransition, TransitionKey
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
from .device_processor import DeviceProcessorStep
from .gym_action_processor import Numpy2TorchActionProcessorStep, Torch2NumpyActionProcessorStep
from .hil_processor import (
    AddTeleopActionAsComplimentaryDataStep,
    AddTeleopEventsAsInfoStep,
    GripperPenaltyProcessorStep,
    ImageCropResizeProcessorStep,
    InterventionActionProcessorStep,
    RewardClassifierProcessorStep,
    TimeLimitProcessorStep,
)
from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats
from .observation_processor import VanillaObservationProcessorStep
from .device_processor import DeviceProcessor
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from .observation_processor import VanillaObservationProcessor
from .pipeline import (
    ActionProcessorStep,
    ComplementaryDataProcessorStep,
    DataProcessorPipeline,
    DoneProcessorStep,
    IdentityProcessorStep,
    InfoProcessorStep,
    ObservationProcessorStep,
    PolicyProcessorPipeline,
    ProcessorKwargs,
    ActionProcessor,
    DoneProcessor,
    EnvTransition,
    IdentityProcessor,
    InfoProcessor,
    ObservationProcessor,
    ProcessorStep,
    ProcessorStepRegistry,
    RewardProcessorStep,
    RobotProcessorPipeline,
    TruncatedProcessorStep,
    RewardProcessor,
    RobotProcessor,
    TransitionKey,
    TruncatedProcessor,
)
from .rename_processor import RenameProcessorStep
from .tokenizer_processor import TokenizerProcessorStep
from .rename_processor import RenameProcessor

__all__ = [
    "ActionProcessorStep",
    "AddTeleopActionAsComplimentaryDataStep",
    "AddTeleopEventsAsInfoStep",
    "ComplementaryDataProcessorStep",
    "batch_to_transition",
    "create_transition",
    "DeviceProcessorStep",
    "DoneProcessorStep",
    "ActionProcessor",
    "DeviceProcessor",
    "DoneProcessor",
    "EnvTransition",
    "GripperPenaltyProcessorStep",
    "hotswap_stats",
    "IdentityProcessorStep",
    "ImageCropResizeProcessorStep",
    "InfoProcessorStep",
    "InterventionActionProcessorStep",
    "JointVelocityProcessorStep",
    "MapDeltaActionToRobotActionStep",
    "MapTensorToDeltaActionDictStep",
    "merge_transitions",
    "MotorCurrentProcessorStep",
    "NormalizerProcessorStep",
    "Numpy2TorchActionProcessorStep",
    "ObservationProcessorStep",
    "PolicyProcessorPipeline",
    "ProcessorKwargs",
    "IdentityProcessor",
    "InfoProcessor",
    "NormalizerProcessor",
    "UnnormalizerProcessor",
    "ObservationProcessor",
    "ProcessorStep",
    "ProcessorStepRegistry",
    "RenameProcessorStep",
    "RewardClassifierProcessorStep",
    "RewardProcessorStep",
    "DataProcessorPipeline",
    "TimeLimitProcessorStep",
    "AddBatchDimensionProcessorStep",
    "RobotProcessorPipeline",
    "TokenizerProcessorStep",
    "Torch2NumpyActionProcessorStep",
    "transition_to_batch",
    "transition_to_dataset_frame",
    "RenameProcessor",
    "RewardProcessor",
    "RobotProcessor",
    "TransitionKey",
    "TruncatedProcessorStep",
    "UnnormalizerProcessorStep",
    "VanillaObservationProcessorStep",
    "TruncatedProcessor",
    "VanillaObservationProcessor",
]

@@ -1,159 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field

from torch import Tensor

from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE

from .core import EnvTransition
from .pipeline import (
    ActionProcessorStep,
    ComplementaryDataProcessorStep,
    ObservationProcessorStep,
    ProcessorStep,
    ProcessorStepRegistry,
)


@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_action")
class AddBatchDimensionActionStep(ActionProcessorStep):
    """Process action component in-place, adding batch dimension if needed."""

    def action(self, action):
        if not isinstance(action, Tensor) or action.dim() != 1:
            return action

        return action.unsqueeze(0)

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features


@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_observation")
class AddBatchDimensionObservationStep(ObservationProcessorStep):
    """Process observation component in-place, adding batch dimensions where needed."""

    def observation(self, observation):
        # Process state observations - add batch dim if 1D
        for state_key in [OBS_STATE, OBS_ENV_STATE]:
            if state_key in observation:
                state_value = observation[state_key]
                if isinstance(state_value, Tensor) and state_value.dim() == 1:
                    observation[state_key] = state_value.unsqueeze(0)

        # Process single image observation - add batch dim if 3D
        if OBS_IMAGE in observation:
            image_value = observation[OBS_IMAGE]
            if isinstance(image_value, Tensor) and image_value.dim() == 3:
                observation[OBS_IMAGE] = image_value.unsqueeze(0)

        # Process multiple image observations - add batch dim if 3D
        for key, value in observation.items():
            if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
                observation[key] = value.unsqueeze(0)
        return observation

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features


@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
    """Process complementary data in-place, handling task field batching."""

    def complementary_data(self, complementary_data):
        # Process task field - wrap string in list to add batch dimension
        if "task" in complementary_data:
            task_value = complementary_data["task"]
            if isinstance(task_value, str):
                complementary_data["task"] = [task_value]

        # Process index field - add batch dim if 0D
        if "index" in complementary_data:
            index_value = complementary_data["index"]
            if isinstance(index_value, Tensor) and index_value.dim() == 0:
                complementary_data["index"] = index_value.unsqueeze(0)

        # Process task_index field - add batch dim if 0D
        if "task_index" in complementary_data:
            task_index_value = complementary_data["task_index"]
            if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
                complementary_data["task_index"] = task_index_value.unsqueeze(0)
        return complementary_data

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features


@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor")
class AddBatchDimensionProcessorStep(ProcessorStep):
    """Processor that adds batch dimensions to observations and actions when needed.

    This processor ensures that observations and actions have proper batch dimensions for model processing:

    - For state observations (observation.state, observation.environment_state):
      Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional

    - For image observations (observation.image, observation.images.*):
      Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)

    - For actions:
      Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional

    - For task field in complementary data:
      Wraps string task in a list to add batch dimension
      (task must be a string or list of strings)

    This is useful when processing single transitions that need to be batched for
    model inference or when converting from unbatched environment outputs to
    batched model inputs.

    The processor only modifies tensors that need batching and leaves already
    batched tensors unchanged.

    Example:
    ```python
    # State: (7,) -> (1, 7)
    # Image: (224, 224, 3) -> (1, 224, 224, 3)
    # Action: (4,) -> (1, 4)
    # Task: "pick_cube" -> ["pick_cube"]
    # Already batched: (1, 7) -> (1, 7) [unchanged]
    ```
    """

    to_batch_action_processor: AddBatchDimensionActionStep = field(
        default_factory=AddBatchDimensionActionStep
    )
    to_batch_observation_processor: AddBatchDimensionObservationStep = field(
        default_factory=AddBatchDimensionObservationStep
    )
    to_batch_complementary_data_processor: AddBatchDimensionComplementaryDataStep = field(
        default_factory=AddBatchDimensionComplementaryDataStep
    )

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        transition = self.to_batch_action_processor(transition)
        transition = self.to_batch_observation_processor(transition)
        transition = self.to_batch_complementary_data_processor(transition)
        return transition

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # NOTE: We ignore the batch dimension when transforming features
        return features
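A minimal sketch of the combined step in action (hypothetical usage; assumes the helpers `create_transition` and `TransitionKey` from the converters/core modules shown later in this diff, and that the pipeline wrappers accept a raw tensor action):

```python
import torch

# Single unbatched transition: 1-D state, 1-D action, plain string task.
transition = create_transition(
    observation={"observation.state": torch.zeros(7)},
    action=torch.zeros(4),
    complementary_data={"task": "pick_cube"},
)
batched = AddBatchDimensionProcessorStep()(transition)

assert batched[TransitionKey.OBSERVATION]["observation.state"].shape == (1, 7)
assert batched[TransitionKey.ACTION].shape == (1, 4)
assert batched[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["pick_cube"]
```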
@@ -1,481 +0,0 @@
# !/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
from copy import deepcopy
from functools import singledispatch
from typing import Any

import numpy as np
import torch
from scipy.spatial.transform import Rotation

from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED

from .core import EnvTransition, TransitionKey


@singledispatch
def to_tensor(
    value: Any,
    *,
    dtype: torch.dtype | None = torch.float32,
    device: torch.device | str | None = None,
) -> torch.Tensor:
    """
    Convert various data types to PyTorch tensors with configurable options.

    This is a unified tensor conversion function using single dispatch to handle
    different input types appropriately.

    Args:
        value: Input value to convert (tensor, array, scalar, sequence, etc.)
        dtype: Target tensor dtype. If None, preserves original dtype.
        device: Target device for the tensor.

    Returns:
        PyTorch tensor.

    Raises:
        TypeError: If the input type is not supported.
    """
    raise TypeError(f"Unsupported type for tensor conversion: {type(value)}")


@to_tensor.register(torch.Tensor)
def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
    """Handle existing PyTorch tensors."""
    if dtype is not None:
        value = value.to(dtype=dtype)
    if device is not None:
        value = value.to(device=device)
    return value


@to_tensor.register(np.ndarray)
def _(
    value: np.ndarray,
    *,
    dtype=torch.float32,
    device=None,
    **kwargs,
) -> torch.Tensor:
    """Handle numpy arrays."""
    # Check for numpy scalars (0-dimensional arrays) and treat them as scalars
    if value.ndim == 0:
        # Numpy scalars should be converted to 0-dimensional tensors
        scalar_value = value.item()
        return torch.tensor(scalar_value, dtype=dtype, device=device)

    # Create tensor from numpy array (torch.from_numpy handles contiguity automatically)
    tensor = torch.from_numpy(value)

    # Apply dtype conversion if specified
    if dtype is not None:
        tensor = tensor.to(dtype=dtype)
    if device is not None:
        tensor = tensor.to(device=device)

    return tensor


@to_tensor.register(int)
@to_tensor.register(float)
@to_tensor.register(np.integer)
@to_tensor.register(np.floating)
def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
    """Handle scalar values including numpy scalars."""
    return torch.tensor(value, dtype=dtype, device=device)


@to_tensor.register(list)
@to_tensor.register(tuple)
def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
    """Handle sequences (lists, tuples)."""
    return torch.tensor(value, dtype=dtype, device=device)


@to_tensor.register(dict)
def _(value: dict, *, device=None, **kwargs) -> dict:
    """Handle dictionaries by recursively converting values to tensors."""
    if not value:
        return {}

    result = {}
    for key, sub_value in value.items():
        if sub_value is None:
            continue

        if isinstance(sub_value, dict):
            # Recursively process nested dictionaries
            result[key] = to_tensor(
                sub_value,
                device=device,
                **kwargs,
            )
            continue

        # Convert individual values to tensors
        result[key] = to_tensor(
            sub_value,
            device=device,
            **kwargs,
        )
    return result
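A quick sketch of the dispatch routes above (illustrative only; each call hits one of the registered handlers):

```python
import numpy as np
import torch

t1 = to_tensor(np.ones((2, 3), dtype=np.float64))  # ndarray -> float32 tensor
t2 = to_tensor(5)                                  # Python scalar -> 0-d tensor
t3 = to_tensor([1.0, 2.0, 3.0])                    # sequence -> 1-d tensor
t4 = to_tensor({"a": 1.0, "b": {"c": [2.0]}})      # dict -> dict of tensors (recursive)

assert t1.dtype == torch.float32 and t1.shape == (2, 3)
assert t2.dim() == 0
assert t3.shape == (3,)
assert isinstance(t4, dict) and t4["b"]["c"].shape == (1,)
```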


def _from_tensor(x: torch.Tensor | Any) -> np.ndarray | float | int | Any:
    """Convert tensor to numpy/scalar if needed."""
    if isinstance(x, torch.Tensor):
        return x.item() if x.numel() == 1 else x.detach().cpu().numpy()
    return x


def _is_image(arr: Any) -> bool:
    return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3


def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    state, images = {}, {}
    for k, v in obs.items():
        if "image" in k.lower() or _is_image(v):
            images[k] = v
        else:
            state[k] = v
    return state, images


# ============================================================================
# Private Helper Functions (Common Logic)
# ============================================================================


def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
    """Extract complementary data (pad flags, task, index, task_index)."""
    pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
    task_key = {"task": batch["task"]} if "task" in batch else {}
    index_key = {"index": batch["index"]} if "index" in batch else {}
    task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}

    return {**pad_keys, **task_key, **index_key, **task_index_key}


def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransition:
    """Merge two transitions, with other taking precedence."""
    out = deepcopy(base)

    for key in (
        TransitionKey.OBSERVATION,
        TransitionKey.ACTION,
        TransitionKey.INFO,
        TransitionKey.COMPLEMENTARY_DATA,
    ):
        if other.get(key):
            out.setdefault(key, {}).update(deepcopy(other[key]))

    for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED):
        if k in other:
            out[k] = other[k]
    return out


# ============================================================================
# Core Conversion Functions
# ============================================================================


def create_transition(
    observation: dict[str, Any] | None = None,
    action: dict[str, Any] | None = None,
    reward: float = 0.0,
    done: bool = False,
    truncated: bool = False,
    info: dict[str, Any] | None = None,
    complementary_data: dict[str, Any] | None = None,
) -> EnvTransition:
    """Create an EnvTransition with sensible defaults.

    Args:
        observation: Observation dictionary.
        action: Action dictionary.
        reward: Scalar reward value.
        done: Episode termination flag.
        truncated: Episode truncation flag.
        info: Additional info dictionary.
        complementary_data: Complementary data dictionary.

    Returns:
        Complete EnvTransition dictionary.
    """
    return {
        TransitionKey.OBSERVATION: observation,
        TransitionKey.ACTION: action,
        TransitionKey.REWARD: reward,
        TransitionKey.DONE: done,
        TransitionKey.TRUNCATED: truncated,
        TransitionKey.INFO: info if info is not None else {},
        TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
    }
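A minimal sketch of building a transition and reading it back (illustrative only):

```python
# Unspecified fields are filled with the defaults documented above.
tr = create_transition(
    observation={"observation.state.x": to_tensor(0.1)},
    reward=1.0,
    done=True,
)
assert tr[TransitionKey.REWARD] == 1.0
assert tr[TransitionKey.DONE] is True
assert tr[TransitionKey.INFO] == {}  # default filled in
```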


def action_to_transition(action: dict[str, Any]) -> EnvTransition:
    """
    Convert a raw teleop action dict into an EnvTransition under the ACTION TransitionKey.
    """
    act_dict: dict[str, Any] = {}
    for k, v in action.items():
        # Check if the value is a type that should not be converted to a tensor.
        if isinstance(v, (Rotation, dict)):
            act_dict[f"{ACTION}.{k}"] = v
            continue

        arr = np.array(v) if np.isscalar(v) else v
        act_dict[f"{ACTION}.{k}"] = to_tensor(arr)

    return create_transition(observation={}, action=act_dict)


# TODO(Adil, Pepijn): Over time we can maybe add these converters to pipeline.py itself
def observation_to_transition(observation: dict[str, Any]) -> EnvTransition:
    """
    Convert a raw robot observation dict into an EnvTransition under the OBSERVATION TransitionKey.
    """
    state, images = _split_obs_to_state_and_images(observation)

    obs_dict: dict[str, Any] = {}
    for k, v in state.items():
        arr = np.array(v) if np.isscalar(v) else v
        obs_dict[f"{OBS_STATE}.{k}"] = to_tensor(arr)

    for cam, img in images.items():
        obs_dict[f"{OBS_IMAGES}.{cam}"] = img

    return create_transition(observation=obs_dict, action={})


def transition_to_robot_action(transition: EnvTransition) -> dict[str, Any]:
    """
    Converts an EnvTransition under the ACTION TransitionKey to a dict with keys ending in '.pos' or '.vel' for raw robot actions.
    """
    out: dict[str, Any] = {}
    action_dict = transition.get(TransitionKey.ACTION) or {}

    if action_dict is None:
        return out

    for k, v in action_dict.items():
        if isinstance(k, str) and k.startswith(f"{ACTION}.") and k.endswith((".pos", ".vel")):
            out_key = k[len(f"{ACTION}.") :]  # Strip the 'action.' prefix.
            out[out_key] = float(v)

    return out


def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> EnvTransition:
    """Merge multiple transitions or return single transition.

    Args:
        transitions: Either a single transition or iterable of transitions.

    Returns:
        Merged EnvTransition.
    """

    if not isinstance(transitions, Sequence):  # Single transition
        return transitions

    items = list(transitions)
    if not items:
        raise ValueError("merge_transitions() requires a non-empty sequence of transitions")

    result = items[0]
    for t in items[1:]:
        result = _merge_transitions(result, t)
    return result
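A short sketch of the precedence rule (illustrative only): later transitions win for overlapping keys, and scalar fields are taken from the last transition that sets them.

```python
t1 = create_transition(observation={"observation.state.x": 1.0}, reward=0.0)
t2 = create_transition(observation={"observation.state.x": 2.0}, reward=1.0)
merged = merge_transitions([t1, t2])

assert merged[TransitionKey.OBSERVATION]["observation.state.x"] == 2.0
assert merged[TransitionKey.REWARD] == 1.0
```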


def transition_to_dataset_frame(
    transitions_or_transition: EnvTransition | Sequence[EnvTransition], features: dict[str, dict]
) -> dict[str, Any]:
    """Convert a single EnvTransition or an iterable of them into a flat, dataset-friendly dictionary for training or evaluation.

    Processes transitions according to the provided feature specification and returns
    data in the format expected by machine learning models and datasets.

    Args:
        transitions_or_transition: Either a single EnvTransition dict or an iterable of them
            (which will be merged using merge_transitions).
        features: Feature specification dictionary with the following structure:
            - 'action': dict with 'names': list of action feature names
            - 'observation.state': dict with 'names': list of state feature names
            - keys starting with 'observation.images.' are passed through as-is

    Returns:
        Flat dictionary containing:
            - numpy arrays for "observation.state" and "action" (vectorized from feature names)
            - any image tensors defined in features (passed through unchanged)
            - next.{reward,done,truncated} scalar values
            - info dict
            - *_is_pad flags and task from complementary_data
    """
    action_names = features.get(ACTION, {}).get("names", [])
    obs_state_names = features.get(OBS_STATE, {}).get("names", [])
    image_keys = [k for k in features if k.startswith(OBS_IMAGES)]

    tr = merge_transitions(transitions_or_transition)
    obs = tr.get(TransitionKey.OBSERVATION, {}) or {}
    act = tr.get(TransitionKey.ACTION, {}) or {}
    batch: dict[str, Any] = {}

    # Images passthrough
    for k in image_keys:
        if k in obs:
            batch[k] = obs[k]

    # Observation.state vector
    if obs_state_names:
        vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
        batch[OBS_STATE] = np.asarray(vals, dtype=np.float32)

    # Action vector
    if action_names:
        vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
        batch[ACTION] = np.asarray(vals, dtype=np.float32)

    # Add transition metadata
    if tr.get(TransitionKey.REWARD) is not None:
        reward_val = _from_tensor(tr[TransitionKey.REWARD])
        # Check if features expect array format, otherwise keep as scalar
        if REWARD in features and features[REWARD].get("shape") == (1,):
            batch[REWARD] = np.array([reward_val], dtype=np.float32)
        else:
            batch[REWARD] = reward_val

    if tr.get(TransitionKey.DONE) is not None:
        done_val = _from_tensor(tr[TransitionKey.DONE])
        if DONE in features and features[DONE].get("shape") == (1,):
            batch[DONE] = np.array([done_val], dtype=bool)
        else:
            batch[DONE] = done_val

    if tr.get(TransitionKey.TRUNCATED) is not None:
        truncated_val = _from_tensor(tr[TransitionKey.TRUNCATED])
        if TRUNCATED in features and features[TRUNCATED].get("shape") == (1,):
            batch[TRUNCATED] = np.array([truncated_val], dtype=bool)
        else:
            batch[TRUNCATED] = truncated_val

    # Complementary data flags and task
    comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
    if comp:
        # pad flags
        for k, v in comp.items():
            if k.endswith("_is_pad"):
                batch[k] = v
        # task label
        if comp.get("task") is not None:
            batch["task"] = comp["task"]

    return batch
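A sketch with a hypothetical two-joint feature spec (the joint names are illustrative; this assumes the `ACTION` and `OBS_STATE` constants equal "action" and "observation.state"):

```python
features = {
    "action": {"names": ["shoulder.pos", "elbow.pos"]},
    "observation.state": {"names": ["shoulder.pos", "elbow.pos"]},
}
tr = create_transition(
    observation={"observation.state.shoulder.pos": 0.1, "observation.state.elbow.pos": 0.2},
    action={"action.shoulder.pos": 0.3, "action.elbow.pos": 0.4},
)
frame = transition_to_dataset_frame(tr, features)
# frame["observation.state"] == array([0.1, 0.2]); frame["action"] == array([0.3, 0.4])
```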


def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
    """Convert a batch dict coming from LeRobot replay/dataset code into an EnvTransition dictionary.

    The function maps well-known keys to the EnvTransition structure. Missing keys are
    filled with sane defaults (None or 0.0/False).

    Keys recognised (case-sensitive):
        * "observation.*" (keys starting with "observation." are grouped into observation dict)
        * "action"
        * "next.reward"
        * "next.done"
        * "next.truncated"
        * "info"
        * "_is_pad" patterns (padding flags)
        * "task", "index", "task_index" (complementary data)

    Additional keys are ignored so that existing dataloaders can carry extra
    metadata without breaking the processor.

    Args:
        batch: Batch dictionary from datasets or dataloaders containing the above keys.

    Returns:
        EnvTransition dictionary with properly structured transition data.
    """

    # Validate input type
    if not isinstance(batch, dict):
        raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")

    # Extract observation keys
    observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
    complementary_data = _extract_complementary_data(batch)

    return create_transition(
        observation=observation_keys if observation_keys else None,
        action=batch.get("action"),
        reward=batch.get("next.reward", 0.0),
        done=batch.get("next.done", False),
        truncated=batch.get("next.truncated", False),
        info=batch.get("info", {}),
        complementary_data=complementary_data if complementary_data else None,
    )


def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
    """Inverse of batch_to_transition. Returns a dict with canonical field names used throughout LeRobot.

    Converts an EnvTransition back to the batch format expected by datasets, dataloaders,
    and other LeRobot components.

    Output format:
        * "action": Action data from transition
        * "next.reward": Reward value (defaults to 0.0)
        * "next.done": Done flag (defaults to False)
        * "next.truncated": Truncated flag (defaults to False)
        * "info": Info dictionary (defaults to {})
        * Flattened observation keys (e.g., "observation.state", "observation.images.cam1")
        * Complementary data fields ("task", "index", "task_index", padding flags)

    Args:
        transition: EnvTransition dictionary to convert.

    Returns:
        Batch dictionary with canonical LeRobot field names suitable for dataloaders.
    """
    batch = {
        "action": transition.get(TransitionKey.ACTION),
        "next.reward": transition.get(TransitionKey.REWARD, 0.0),
        "next.done": transition.get(TransitionKey.DONE, False),
        "next.truncated": transition.get(TransitionKey.TRUNCATED, False),
        "info": transition.get(TransitionKey.INFO, {}),
    }

    # Add complementary data
    comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
    if comp_data:
        batch.update(comp_data)

    # Flatten observation dict
    observation = transition.get(TransitionKey.OBSERVATION)
    if isinstance(observation, dict):
        batch.update(observation)

    return batch
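A round-trip sketch between the two formats (illustrative only):

```python
import torch

batch = {
    "observation.state": to_tensor([0.0, 1.0]),
    "action": to_tensor([0.5]),
    "next.reward": 1.0,
    "next.done": False,
    "task": "pick_cube",
}
tr = batch_to_transition(batch)
restored = transition_to_batch(tr)

assert restored["next.reward"] == 1.0 and restored["task"] == "pick_cube"
assert torch.equal(restored["observation.state"], batch["observation.state"])
```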
@@ -1,49 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from enum import Enum
from typing import Any, TypedDict

import torch


class TransitionKey(str, Enum):
    """Keys for accessing EnvTransition dictionary components."""

    # TODO(Steven): Use consts
    OBSERVATION = "observation"
    ACTION = "action"
    REWARD = "reward"
    DONE = "done"
    TRUNCATED = "truncated"
    INFO = "info"
    COMPLEMENTARY_DATA = "complementary_data"


EnvTransition = TypedDict(
    "EnvTransition",
    {
        TransitionKey.OBSERVATION.value: dict[str, Any] | None,
        TransitionKey.ACTION.value: Any | torch.Tensor | None,
        TransitionKey.REWARD.value: float | torch.Tensor | None,
        TransitionKey.DONE.value: bool | torch.Tensor | None,
        TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
        TransitionKey.INFO.value: dict[str, Any] | None,
        TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None,
    },
)
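A minimal sketch of constructing and indexing an `EnvTransition` by hand (illustrative only; `create_transition` above is the usual entry point):

```python
import torch

tr: EnvTransition = {
    TransitionKey.OBSERVATION: {"observation.state": torch.zeros(3)},
    TransitionKey.ACTION: torch.zeros(2),
    TransitionKey.REWARD: 0.0,
    TransitionKey.DONE: False,
    TransitionKey.TRUNCATED: False,
    TransitionKey.INFO: {},
    TransitionKey.COMPLEMENTARY_DATA: {},
}
# The enum members are used consistently as dictionary keys throughout the processors.
assert tr[TransitionKey.ACTION].shape == (2,)
```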
@@ -1,145 +0,0 @@
# !/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

from torch import Tensor

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION

from .pipeline import ActionProcessorStep, ProcessorStepRegistry


@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
@dataclass
class MapTensorToDeltaActionDictStep(ActionProcessorStep):
    """
    Map a tensor to a delta action dictionary.
    """

    use_gripper: bool = True

    def action(self, action: Tensor) -> dict:
        if action.dim() > 1:
            action = action.squeeze(0)

        # TODO (maractingi): add rotation
        delta_action = {
            f"{ACTION}.delta_x": action[0],
            f"{ACTION}.delta_y": action[1],
            f"{ACTION}.delta_z": action[2],
        }
        if self.use_gripper:
            delta_action[f"{ACTION}.gripper"] = action[3]
        return delta_action

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        features[f"{ACTION}.delta_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        features[f"{ACTION}.delta_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        features[f"{ACTION}.delta_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        if self.use_gripper:
            features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        return features
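A quick sketch of the tensor-to-dict mapping (illustrative only; assumes the `ACTION` constant equals "action"):

```python
import torch

step = MapTensorToDeltaActionDictStep(use_gripper=True)
delta = step.action(torch.tensor([0.01, 0.0, -0.02, 1.0]))  # [dx, dy, dz, gripper]

assert set(delta) == {"action.delta_x", "action.delta_y", "action.delta_z", "action.gripper"}
```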


@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
@dataclass
class MapDeltaActionToRobotActionStep(ActionProcessorStep):
    """
    Map delta actions from teleoperators (gamepad, keyboard) to robot target actions
    for use with inverse kinematics processors.

    Expected input ACTION keys:
        {
            "action.delta_x": float,
            "action.delta_y": float,
            "action.delta_z": float,
            "action.gripper": float (optional),
        }

    Output ACTION keys:
        {
            "action.enabled": bool,
            "action.target_x": float,
            "action.target_y": float,
            "action.target_z": float,
            "action.target_wx": float,
            "action.target_wy": float,
            "action.target_wz": float,
            "action.gripper": float,
        }
    """

    # Scale factors for delta movements
    position_scale: float = 1.0
    rotation_scale: float = 0.0  # No rotation deltas for gamepad/keyboard
    noise_threshold: float = 1e-3  # 1 mm threshold to filter out noise

    def action(self, action: dict) -> dict:
        # NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
        # TODO (maractingi): change this target_xyz naming convention coming from the teleop devices
        delta_x = action.pop(f"{ACTION}.delta_x", 0.0)
        delta_y = action.pop(f"{ACTION}.delta_y", 0.0)
        delta_z = action.pop(f"{ACTION}.delta_z", 0.0)
        gripper = action.pop(f"{ACTION}.gripper", 1.0)  # Default to "stay" (1.0)

        # Determine if the teleoperator is actively providing input
        # Consider enabled if any significant movement delta is detected
        position_magnitude = (delta_x**2 + delta_y**2 + delta_z**2) ** 0.5  # Use Euclidean norm for position
        enabled = position_magnitude > self.noise_threshold  # Small threshold to avoid noise

        # Scale the deltas appropriately
        scaled_delta_x = delta_x * self.position_scale
        scaled_delta_y = delta_y * self.position_scale
        scaled_delta_z = delta_z * self.position_scale

        # For gamepad/keyboard, we don't have rotation input, so set to 0
        # These could be extended in the future for more sophisticated teleoperators
        target_wx = 0.0
        target_wy = 0.0
        target_wz = 0.0

        # Update action with robot target format
        action = {
            f"{ACTION}.enabled": enabled,
            f"{ACTION}.target_x": scaled_delta_x,
            f"{ACTION}.target_y": scaled_delta_y,
            f"{ACTION}.target_z": scaled_delta_z,
            f"{ACTION}.target_wx": target_wx,
            f"{ACTION}.target_wy": target_wy,
            f"{ACTION}.target_wz": target_wz,
            f"{ACTION}.gripper": float(gripper),
        }

        return action

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Transform features to match output format."""
        features.pop(f"{ACTION}.delta_x", None)
        features.pop(f"{ACTION}.delta_y", None)
        features.pop(f"{ACTION}.delta_z", None)
        features.pop(f"{ACTION}.gripper", None)

        features[f"{ACTION}.enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        features[f"{ACTION}.target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        features[f"{ACTION}.target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        features[f"{ACTION}.target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        features[f"{ACTION}.target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        features[f"{ACTION}.target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        features[f"{ACTION}.target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        return features
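A sketch of the noise-gating behaviour (illustrative only; assumes the `ACTION` constant equals "action"): deltas below the 1 mm threshold leave the robot disabled.

```python
step = MapDeltaActionToRobotActionStep(position_scale=1.0)

out = step.action({"action.delta_x": 0.05, "action.delta_y": 0.0, "action.delta_z": 0.0})
assert out["action.enabled"] is True
assert out["action.target_x"] == 0.05 and out["action.gripper"] == 1.0  # gripper defaults to "stay"

idle = step.action({"action.delta_x": 1e-4, "action.delta_y": 0.0, "action.delta_z": 0.0})
assert not idle["action.enabled"]
```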
@@ -19,116 +19,64 @@ from typing import Any
import torch

from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, TransitionKey
from lerobot.utils.utils import get_safe_torch_device

from .core import EnvTransition, TransitionKey
from .pipeline import ProcessorStep, ProcessorStepRegistry


@ProcessorStepRegistry.register("device_processor")
@dataclass
class DeviceProcessorStep(ProcessorStep):
    """Processes transitions by moving tensors to the specified device and optionally converting float dtypes.
class DeviceProcessor:
    """Processes transitions by moving tensors to the specified device.

    This processor ensures that all tensors in the transition are moved to the
    specified device (CPU or GPU) before they are returned. It can also convert
    floating-point tensors to a specified dtype while preserving non-float types
    (int, long, bool, etc.).
    specified device (CPU or GPU) before they are returned.
    """

    device: str = "cpu"
    float_dtype: str | None = None

    DTYPE_MAPPING = {
        "float16": torch.float16,
        "float32": torch.float32,
        "float64": torch.float64,
        "bfloat16": torch.bfloat16,
        "half": torch.float16,
        "float": torch.float32,
        "double": torch.float64,
    }
    device: torch.device = "cpu"

    def __post_init__(self):
        self.tensor_device: torch.device = get_safe_torch_device(self.device)
        self.device = self.tensor_device.type  # cuda might have changed to cuda:1
        self.device = get_safe_torch_device(self.device)
        self.non_blocking = "cuda" in str(self.device)

        # Validate and convert float_dtype string to torch dtype
        if self.float_dtype is not None:
            if self.float_dtype not in self.DTYPE_MAPPING:
                raise ValueError(
                    f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}"
                )

            self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype]
        else:
            self._target_float_dtype = None

    def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """Process a tensor by moving to device and optionally converting float dtype.

        If the tensor is already on a GPU and we're configured for a GPU, it preserves
        that GPU placement (useful for multi-GPU training with Accelerate).
        Otherwise, it moves to the configured device.
        """
        # Determine target device
        if tensor.is_cuda and self.tensor_device.type == "cuda":
            # Both tensor and target are on GPU - preserve tensor's GPU placement
            # This handles multi-GPU scenarios where Accelerate has already placed
            # tensors on the correct GPU for each process
            target_device = tensor.device
        else:
            # Either tensor is on CPU, or we're configured for CPU
            # In both cases, use the configured device
            target_device = self.tensor_device

        # Only move if necessary
        if tensor.device != target_device:
            tensor = tensor.to(target_device, non_blocking=self.non_blocking)

        # Convert float dtype if specified and tensor is floating point
        if self._target_float_dtype is not None and tensor.is_floating_point():
            tensor = tensor.to(dtype=self._target_float_dtype)

        return tensor

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        # Create a copy of the transition
        new_transition = transition.copy()

        simple_tensor_keys = [
            TransitionKey.ACTION,
            TransitionKey.REWARD,
            TransitionKey.DONE,
            TransitionKey.TRUNCATED,
        ]
        # Process observation tensors
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is not None:
            new_observation = {
                k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
                for k, v in observation.items()
            }
            new_transition[TransitionKey.OBSERVATION] = new_observation

        dict_tensor_keys = [
            TransitionKey.OBSERVATION,
            TransitionKey.COMPLEMENTARY_DATA,
        ]
        # Process action tensor
        action = transition.get(TransitionKey.ACTION)
        if action is not None and isinstance(action, torch.Tensor):
            new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)

        # Process simple tensors
        for key in simple_tensor_keys:
            value = transition.get(key)
            if isinstance(value, torch.Tensor):
                new_transition[key] = self._process_tensor(value)
        # Process reward tensor
        reward = transition.get(TransitionKey.REWARD)
        if reward is not None and isinstance(reward, torch.Tensor):
            new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)

        # Process dictionary-like tensors
        for key in dict_tensor_keys:
            data_dict = transition.get(key)
            if data_dict is not None:
                new_data_dict = {
                    k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
                    for k, v in data_dict.items()
                }
                new_transition[key] = new_data_dict
        # Process done tensor
        done = transition.get(TransitionKey.DONE)
        if done is not None and isinstance(done, torch.Tensor):
            new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)

        # Process truncated tensor
        truncated = transition.get(TransitionKey.TRUNCATED)
        if truncated is not None and isinstance(truncated, torch.Tensor):
            new_transition[TransitionKey.TRUNCATED] = truncated.to(
                self.device, non_blocking=self.non_blocking
            )

        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return configuration for serialization."""
        return {"device": self.device, "float_dtype": self.float_dtype}
        return {"device": self.device}

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
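A sketch of the dtype handling in the "Step" variant shown above (illustrative only; assumes the pre-rename `DeviceProcessorStep` with its `float_dtype` option, and the `create_transition`/`TransitionKey` helpers from this diff):

```python
import torch

step = DeviceProcessorStep(device="cpu", float_dtype="float16")
tr = create_transition(
    observation={"observation.state": torch.zeros(3, dtype=torch.float32)},
    done=torch.tensor(False),
)
out = step(tr)

assert out[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float16
assert out[TransitionKey.DONE].dtype == torch.bool  # non-float dtypes are preserved
```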

@@ -1,72 +0,0 @@
#! /usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,

from dataclasses import dataclass

import numpy as np
import torch

from lerobot.configs.types import PolicyFeature

from .converters import to_tensor
from .pipeline import ActionProcessorStep, ProcessorStepRegistry


@ProcessorStepRegistry.register("torch2numpy_action_processor")
@dataclass
class Torch2NumpyActionProcessorStep(ActionProcessorStep):
    """Convert PyTorch tensor actions to NumPy arrays."""

    squeeze_batch_dim: bool = True

    def action(self, action: torch.Tensor) -> np.ndarray:
        if not isinstance(action, torch.Tensor):
            raise TypeError(
                f"Expected torch.Tensor or None, got {type(action).__name__}. "
                "Use appropriate processor for non-tensor actions."
            )

        numpy_action = action.detach().cpu().numpy()

        # Remove batch dimensions but preserve action dimensions
        # Only squeeze if there's a batch dimension (first dim == 1)
        if (
            self.squeeze_batch_dim
            and numpy_action.shape
            and len(numpy_action.shape) > 1
            and numpy_action.shape[0] == 1
        ):
            numpy_action = numpy_action.squeeze(0)

        return numpy_action

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features


@ProcessorStepRegistry.register("numpy2torch_action_processor")
@dataclass
class Numpy2TorchActionProcessorStep(ActionProcessorStep):
    """Convert NumPy array action to PyTorch tensor."""

    def action(self, action: np.ndarray) -> torch.Tensor:
        if not isinstance(action, np.ndarray):
            raise TypeError(
                f"Expected np.ndarray or None, got {type(action).__name__}. "
                "Use appropriate processor for non-tensor actions."
            )
        torch_action = to_tensor(action, dtype=None)  # Preserve original dtype
        return torch_action

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
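A sketch of the torch/numpy bridge in both directions (illustrative only):

```python
import numpy as np
import torch

to_np = Torch2NumpyActionProcessorStep()
to_torch = Numpy2TorchActionProcessorStep()

np_action = to_np.action(torch.zeros(1, 4))  # batch dim (1, 4) squeezed to (4,)
assert isinstance(np_action, np.ndarray) and np_action.shape == (4,)

torch_action = to_torch.action(np_action)    # dtype preserved (float32 here)
assert torch_action.dtype == torch.float32
```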
@@ -1,332 +0,0 @@
import math
import time
from dataclasses import dataclass
from typing import Any

import numpy as np
import torch
import torchvision.transforms.functional as F  # noqa: N812

from lerobot.configs.types import PolicyFeature
from lerobot.constants import ACTION
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents

from .core import EnvTransition, TransitionKey
from .pipeline import (
    ComplementaryDataProcessorStep,
    InfoProcessorStep,
    ObservationProcessorStep,
    ProcessorStep,
    ProcessorStepRegistry,
    TruncatedProcessorStep,
)

GRIPPER_KEY = "gripper"
DISCRETE_PENALTY_KEY = "discrete_penalty"
TELEOP_ACTION_KEY = "teleop_action"


@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
@dataclass
class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
    """Add teleoperator action to transition complementary data."""

    teleop_device: Teleoperator

    def complementary_data(self, complementary_data: dict) -> dict:
        new_complementary_data = dict(complementary_data)
        new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
        return new_complementary_data

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features


@ProcessorStepRegistry.register("add_teleop_action_as_info")
@dataclass
class AddTeleopEventsAsInfoStep(InfoProcessorStep):
    """Add teleoperator control events to transition info."""

    teleop_device: Teleoperator

    def info(self, info: dict) -> dict:
        new_info = dict(info)
        teleop_events = getattr(self.teleop_device, "get_teleop_events", lambda: {})()
        new_info.update(teleop_events)
        return new_info

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@ProcessorStepRegistry.register("image_crop_resize_processor")
|
||||
@dataclass
|
||||
class ImageCropResizeProcessorStep(ObservationProcessorStep):
|
||||
"""Crop and resize image observations."""
|
||||
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
if self.resize_size is None and not self.crop_params_dict:
|
||||
return observation
|
||||
|
||||
new_observation = dict(observation)
|
||||
|
||||
# Process all image keys in the observation
|
||||
for key in observation:
|
||||
if "image" not in key:
|
||||
continue
|
||||
|
||||
image = observation[key]
|
||||
device = image.device
|
||||
# NOTE (maractingi): No mps kernel for crop and resize, so we need to move to cpu
|
||||
if device.type == "mps":
|
||||
image = image.cpu()
|
||||
# Crop if crop params are provided for this key
|
||||
if self.crop_params_dict is not None and key in self.crop_params_dict:
|
||||
crop_params = self.crop_params_dict[key]
|
||||
image = F.crop(image, *crop_params)
|
||||
if self.resize_size is not None:
|
||||
image = F.resize(image, self.resize_size)
|
||||
image = image.clamp(0.0, 1.0)
|
||||
new_observation[key] = image.to(device)
|
||||
|
||||
return new_observation
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"crop_params_dict": self.crop_params_dict,
|
||||
"resize_size": self.resize_size,
|
||||
}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
if self.resize_size is None:
|
||||
return features
|
||||
for key in features:
|
||||
if "image" in key:
|
||||
features[key] = PolicyFeature(type=features[key].type, shape=self.resize_size)
|
||||
return features
|
||||
|
||||
|
||||
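For reference, `torchvision.transforms.functional.crop` takes `(top, left, height, width)`. A hedged configuration sketch with a made-up camera key and numbers:

```python
# Hypothetical crop/resize configuration: take a 480x480 window starting at
# (top=0, left=80) out of each 480x640 frame, then resize to 128x128.
step = ImageCropResizeProcessorStep(
    crop_params_dict={"observation.images.cam_high": (0, 80, 480, 480)},
    resize_size=(128, 128),
)
```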
@dataclass
@ProcessorStepRegistry.register("time_limit_processor")
class TimeLimitProcessorStep(TruncatedProcessorStep):
    """Track episode steps and enforce time limits."""

    max_episode_steps: int
    current_step: int = 0

    def truncated(self, truncated):
        self.current_step += 1
        if self.current_step >= self.max_episode_steps:
            truncated = True
        # TODO (steven): missing an else truncated = False?
        return truncated

    def get_config(self) -> dict[str, Any]:
        return {
            "max_episode_steps": self.max_episode_steps,
        }

    def reset(self) -> None:
        self.current_step = 0

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
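An illustrative trace of the step above, assuming it is constructed directly:

```python
# With max_episode_steps=3 the third call flips the truncated flag to True.
step = TimeLimitProcessorStep(max_episode_steps=3)
flags = [step.truncated(False) for _ in range(3)]  # [False, False, True]
step.reset()  # start counting again for the next episode
```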
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
    """Apply a penalty for inappropriate gripper usage."""

    penalty: float = -0.01
    max_gripper_pos: float = 30.0

    def complementary_data(self, complementary_data):
        """Calculate the gripper penalty and add it to the complementary data."""
        action = self.transition.get(TransitionKey.ACTION)

        current_gripper_pos = complementary_data.get("raw_joint_positions", {}).get(GRIPPER_KEY)
        if current_gripper_pos is None:
            return complementary_data

        gripper_action = action[f"{ACTION}.{GRIPPER_KEY}.pos"]
        gripper_action_normalized = gripper_action / self.max_gripper_pos

        # Normalize gripper state and action
        gripper_state_normalized = current_gripper_pos / self.max_gripper_pos

        # Penalize opening an almost-closed gripper or closing an almost-open one
        gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
            gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
        )

        gripper_penalty = self.penalty * int(gripper_penalty_bool)

        # Create new complementary data with penalty info
        new_complementary_data = dict(complementary_data)
        new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty

        return new_complementary_data

    def get_config(self) -> dict[str, Any]:
        return {
            "penalty": self.penalty,
            "max_gripper_pos": self.max_gripper_pos,
        }

    def reset(self) -> None:
        """Reset the processor state."""
        self.last_gripper_state = None

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
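A worked example of the penalty logic with made-up positions:

```python
# Gripper nearly closed (pos 10 of 30) while the action commands it wide
# open (pos 25 of 30): the first branch of the condition fires.
state_norm = 10 / 30   # 0.33 < 0.5  -> gripper is "closed"
action_norm = 25 / 30  # 0.83 > 0.5  -> action says "open"
penalty = -0.01 * int(state_norm < 0.5 and action_norm > 0.5)  # -0.01
```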
@dataclass
@ProcessorStepRegistry.register("intervention_action_processor")
class InterventionActionProcessorStep(ProcessorStep):
    """Handle human intervention actions and episode termination."""

    use_gripper: bool = False
    terminate_on_success: bool = True

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition.get(TransitionKey.ACTION)
        if action is None:
            return transition

        # Get intervention signals from info and complementary data
        info = transition.get(TransitionKey.INFO, {})
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
        teleop_action = complementary_data.get(TELEOP_ACTION_KEY, {})
        is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
        terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
        success = info.get(TeleopEvents.SUCCESS, False)
        rerecord_episode = info.get(TeleopEvents.RERECORD_EPISODE, False)

        new_transition = transition.copy()

        # Override the action if an intervention is active
        if is_intervention and teleop_action is not None:
            if isinstance(teleop_action, dict):
                # Convert the teleop_action dict to tensor format
                action_list = [
                    teleop_action.get(f"{ACTION}.delta_x", 0.0),
                    teleop_action.get(f"{ACTION}.delta_y", 0.0),
                    teleop_action.get(f"{ACTION}.delta_z", 0.0),
                ]
                if self.use_gripper:
                    action_list.append(teleop_action.get(GRIPPER_KEY, 1.0))
            elif isinstance(teleop_action, np.ndarray):
                action_list = teleop_action.tolist()
            else:
                action_list = teleop_action

            teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device)
            new_transition[TransitionKey.ACTION] = teleop_action_tensor

        # Handle episode termination
        new_transition[TransitionKey.DONE] = bool(terminate_episode) or (
            self.terminate_on_success and success
        )
        new_transition[TransitionKey.REWARD] = float(success)

        # Update info with intervention metadata
        info = new_transition.get(TransitionKey.INFO, {})
        info[TeleopEvents.IS_INTERVENTION] = is_intervention
        info[TeleopEvents.RERECORD_EPISODE] = rerecord_episode
        info[TeleopEvents.SUCCESS] = success
        new_transition[TransitionKey.INFO] = info

        # Update complementary data with the teleop action
        complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
        complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION)
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data

        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {
            "use_gripper": self.use_gripper,
            "terminate_on_success": self.terminate_on_success,
        }

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
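An illustrative intervention, assuming `ACTION` resolves to the string `"action"`:

```python
# Hypothetical teleop dict; with use_gripper=True this becomes a 4-D action.
teleop_action = {
    "action.delta_x": 0.01,
    "action.delta_y": 0.00,
    "action.delta_z": -0.02,
    "gripper": 1.0,
}
# -> torch.tensor([0.0100, 0.0000, -0.0200, 1.0000]) replaces the policy action
```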
@dataclass
@ProcessorStepRegistry.register("reward_classifier_processor")
class RewardClassifierProcessorStep(ProcessorStep):
    """Apply reward classification to image observations."""

    pretrained_path: str | None = None
    device: str = "cpu"
    success_threshold: float = 0.5
    success_reward: float = 1.0
    terminate_on_success: bool = True

    reward_classifier: Any = None

    def __post_init__(self):
        """Initialize the reward classifier after dataclass initialization."""
        if self.pretrained_path is not None:
            from lerobot.policies.sac.reward_model.modeling_classifier import Classifier

            self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
            self.reward_classifier.to(self.device)
            self.reward_classifier.eval()

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        new_transition = transition.copy()
        observation = new_transition.get(TransitionKey.OBSERVATION)
        if observation is None or self.reward_classifier is None:
            return new_transition

        # Extract images from the observation
        images = {key: value for key, value in observation.items() if "image" in key}

        if not images:
            return new_transition

        # Run the reward classifier
        start_time = time.perf_counter()
        with torch.inference_mode():
            success = self.reward_classifier.predict_reward(images, threshold=self.success_threshold)

        classifier_frequency = 1 / (time.perf_counter() - start_time)

        # Calculate reward and termination
        reward = new_transition.get(TransitionKey.REWARD, 0.0)
        terminated = new_transition.get(TransitionKey.DONE, False)

        if math.isclose(success, 1, abs_tol=1e-2):
            reward = self.success_reward
            if self.terminate_on_success:
                terminated = True

        # Update the transition
        new_transition[TransitionKey.REWARD] = reward
        new_transition[TransitionKey.DONE] = terminated

        # Update info with the classifier frequency
        info = new_transition.get(TransitionKey.INFO, {})
        info["reward_classifier_frequency"] = classifier_frequency
        new_transition[TransitionKey.INFO] = info

        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {
            "device": self.device,
            "success_threshold": self.success_threshold,
            "success_reward": self.success_reward,
            "terminate_on_success": self.terminate_on_success,
        }

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@@ -1,109 +0,0 @@
from dataclasses import dataclass
from typing import Any

import torch

from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_STATE
from lerobot.processor.pipeline import (
    ObservationProcessorStep,
    ProcessorStepRegistry,
)
from lerobot.robots import Robot


@dataclass
@ProcessorStepRegistry.register("joint_velocity_processor")
class JointVelocityProcessorStep(ObservationProcessorStep):
    """Add joint velocity information to observations."""

    dt: float = 0.1

    last_joint_positions: torch.Tensor | None = None

    def observation(self, observation: dict) -> dict:
        # Get current joint positions (assuming they're in observation.state)
        current_positions = observation.get(OBS_STATE)
        if current_positions is None:
            # TODO(steven): if we get here, then the transform_features method will not hold
            raise ValueError(f"{OBS_STATE} is not in observation")

        # Initialize last joint positions if not already set
        if self.last_joint_positions is None:
            self.last_joint_positions = current_positions.clone()
            joint_velocities = torch.zeros_like(current_positions)
        else:
            # Compute velocities
            joint_velocities = (current_positions - self.last_joint_positions) / self.dt

        self.last_joint_positions = current_positions.clone()

        # Extend the observation state with velocities
        extended_state = torch.cat([current_positions, joint_velocities], dim=-1)

        # Create a new observation dict
        new_observation = dict(observation)
        new_observation[OBS_STATE] = extended_state

        return new_observation

    def get_config(self) -> dict[str, Any]:
        return {
            "dt": self.dt,
        }

    def reset(self) -> None:
        self.last_joint_positions = None

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        if OBS_STATE in features:
            original_feature = features[OBS_STATE]
            # Double the shape to account for positions + velocities
            new_shape = (original_feature.shape[0] * 2,) + original_feature.shape[1:]

            features[OBS_STATE] = PolicyFeature(type=original_feature.type, shape=new_shape)
        return features
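The velocity estimate above is a first-order finite difference; a worked example with made-up numbers:

```python
import torch

dt = 0.1  # seconds between observations
prev = torch.tensor([0.10, 0.20])       # previous joint positions (rad)
curr = torch.tensor([0.15, 0.18])       # current joint positions (rad)
vel = (curr - prev) / dt                # tensor([ 0.5000, -0.2000]) rad/s
state = torch.cat([curr, vel], dim=-1)  # state shape doubles: (2,) -> (4,)
```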
@dataclass
@ProcessorStepRegistry.register("current_processor")
class MotorCurrentProcessorStep(ObservationProcessorStep):
    """Add motor current information to observations."""

    robot: Robot | None = None

    def observation(self, observation: dict) -> dict:
        # Get current values from the robot state
        if self.robot is None:
            raise ValueError("Robot is not set")

        present_current_dict = self.robot.bus.sync_read("Present_Current")  # type: ignore[attr-defined]
        motor_currents = torch.tensor(
            [present_current_dict[name] for name in self.robot.bus.motors],  # type: ignore[attr-defined]
            dtype=torch.float32,
        ).unsqueeze(0)

        current_state = observation.get(OBS_STATE)
        if current_state is None:
            return observation

        extended_state = torch.cat([current_state, motor_currents], dim=-1)

        # Create a new observation dict
        new_observation = dict(observation)
        new_observation[OBS_STATE] = extended_state

        return new_observation

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        if OBS_STATE in features and self.robot is not None:
            original_feature = features[OBS_STATE]
            # Add motor current dimensions to the original state shape
            num_motors = 0
            if hasattr(self.robot, "bus") and hasattr(self.robot.bus, "motors"):  # type: ignore[attr-defined]
                num_motors = len(self.robot.bus.motors)  # type: ignore[attr-defined]

            if num_motors > 0:
                new_shape = (original_feature.shape[0] + num_motors,) + original_feature.shape[1:]
                features[OBS_STATE] = PolicyFeature(type=original_feature.type, shape=new_shape)
        return features
@@ -1,503 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Generic script to migrate any policy model with normalization layers to the new pipeline-based system.

This script:
1. Loads an existing pretrained policy model
2. Extracts normalization statistics from the model
3. Creates both preprocessor and postprocessor:
   - Preprocessor: normalizes both inputs (observations) and outputs (actions) for training
   - Postprocessor: unnormalizes outputs (actions) for inference
4. Removes normalization layers from the model state_dict
5. Saves the new model and both processors

Usage:
    python src/lerobot/processor/migrate_policy_normalization.py \
        --pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
        --policy-type act \
        --push-to-hub
"""

import argparse
import importlib
import json
import os
from copy import deepcopy
from pathlib import Path
from typing import Any

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

from .batch_processor import AddBatchDimensionProcessorStep
from .device_processor import DeviceProcessorStep
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
from .pipeline import PolicyProcessorPipeline
from .rename_processor import RenameProcessorStep

# Policy type to class mapping
POLICY_CLASSES = {
    "act": "lerobot.policies.act.modeling_act.ACTPolicy",
    "diffusion": "lerobot.policies.diffusion.modeling_diffusion.DiffusionPolicy",
    "pi0": "lerobot.policies.pi0.modeling_pi0.PI0Policy",
    "pi0fast": "lerobot.policies.pi0fast.modeling_pi0fast.PI0FASTPolicy",
    "smolvla": "lerobot.policies.smolvla.modeling_smolvla.SmolVLAPolicy",
    "tdmpc": "lerobot.policies.tdmpc.modeling_tdmpc.TDMPCPolicy",
    "vqbet": "lerobot.policies.vqbet.modeling_vqbet.VQBeTPolicy",
    "sac": "lerobot.policies.sac.modeling_sac.SACPolicy",
    "classifier": "lerobot.policies.classifier.modeling_classifier.ClassifierPolicy",
}
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
    """Extract normalization statistics from model state_dict."""
    stats = {}

    # Define patterns to match and their prefixes to remove
    normalization_patterns = [
        "normalize_inputs.buffer_",
        "unnormalize_outputs.buffer_",
        "normalize_targets.buffer_",
        "normalize.",  # Must come after normalize_* patterns
        "unnormalize.",  # Must come after unnormalize_* patterns
        "input_normalizer.",
        "output_normalizer.",
    ]

    # Process each key in state_dict
    for key, tensor in state_dict.items():
        # Try each pattern
        for pattern in normalization_patterns:
            if key.startswith(pattern):
                # Extract the remaining part after the pattern
                remaining = key[len(pattern) :]
                parts = remaining.split(".")

                # Need at least a feature name and a stat type
                if len(parts) >= 2:
                    # Last part is the stat type (mean, std, min, max, etc.)
                    stat_type = parts[-1]
                    # Everything else is the feature name
                    feature_name = ".".join(parts[:-1]).replace("_", ".")

                    # Add to stats
                    if feature_name not in stats:
                        stats[feature_name] = {}
                    stats[feature_name][stat_type] = tensor.clone()

                # Only process the first matching pattern
                break

    return stats
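A worked example of the key parsing above. Note that the final `replace("_", ".")` assumes feature names never contain literal underscores, so a buffer name containing one (e.g. `cam_high`) would come back as `cam.high`:

```python
key = "normalize_inputs.buffer_observation_state.mean"
pattern = "normalize_inputs.buffer_"
remaining = key[len(pattern):]    # "observation_state.mean"
parts = remaining.split(".")      # ["observation_state", "mean"]
stat_type = parts[-1]             # "mean"
feature_name = ".".join(parts[:-1]).replace("_", ".")  # "observation.state"
```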
def detect_features_and_norm_modes(
    config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
    """Detect features and normalization modes from config and stats."""
    features = {}
    norm_modes = {}

    # First, check if there's a normalization_mapping in the config
    if "normalization_mapping" in config:
        print(f"Found normalization_mapping in config: {config['normalization_mapping']}")
        # Extract normalization modes from config
        for feature_name, mode_str in config["normalization_mapping"].items():
            # Convert string to NormalizationMode enum
            if mode_str == "mean_std":
                mode = NormalizationMode.MEAN_STD
            elif mode_str == "min_max":
                mode = NormalizationMode.MIN_MAX
            else:
                print(f"Warning: Unknown normalization mode '{mode_str}' for feature '{feature_name}'")
                continue

            # Determine feature type from feature name
            if "image" in feature_name or "visual" in feature_name:
                feature_type = FeatureType.VISUAL
            elif "state" in feature_name:
                feature_type = FeatureType.STATE
            elif "action" in feature_name:
                feature_type = FeatureType.ACTION
            else:
                feature_type = FeatureType.STATE

            norm_modes[feature_type] = mode

    # Try to extract from config
    if "features" in config:
        for key, feature_config in config["features"].items():
            shape = feature_config.get("shape", feature_config.get("dim"))
            shape = (shape,) if isinstance(shape, int) else tuple(shape)

            # Determine feature type
            if "image" in key or "visual" in key:
                feature_type = FeatureType.VISUAL
            elif "state" in key:
                feature_type = FeatureType.STATE
            elif "action" in key:
                feature_type = FeatureType.ACTION
            else:
                feature_type = FeatureType.STATE  # Default

            features[key] = PolicyFeature(feature_type, shape)

    # If no features in config, infer from stats
    if not features:
        for key, stat_dict in stats.items():
            # Get shape from any stat tensor
            tensor = next(iter(stat_dict.values()))
            shape = tuple(tensor.shape)

            # Determine feature type based on key
            if "image" in key or "visual" in key or "pixels" in key:
                feature_type = FeatureType.VISUAL
            elif "state" in key or "joint" in key or "position" in key:
                feature_type = FeatureType.STATE
            elif "action" in key:
                feature_type = FeatureType.ACTION
            else:
                feature_type = FeatureType.STATE

            features[key] = PolicyFeature(feature_type, shape)

    # If normalization modes weren't in config, determine based on available stats
    if not norm_modes:
        for key, stat_dict in stats.items():
            if key in features:
                if "mean" in stat_dict and "std" in stat_dict:
                    feature_type = features[key].type
                    if feature_type not in norm_modes:
                        norm_modes[feature_type] = NormalizationMode.MEAN_STD
                elif "min" in stat_dict and "max" in stat_dict:
                    feature_type = features[key].type
                    if feature_type not in norm_modes:
                        norm_modes[feature_type] = NormalizationMode.MIN_MAX

    # Default normalization modes if not detected
    if FeatureType.VISUAL not in norm_modes:
        norm_modes[FeatureType.VISUAL] = NormalizationMode.MEAN_STD
    if FeatureType.STATE not in norm_modes:
        norm_modes[FeatureType.STATE] = NormalizationMode.MIN_MAX
    if FeatureType.ACTION not in norm_modes:
        norm_modes[FeatureType.ACTION] = NormalizationMode.MEAN_STD

    return features, norm_modes
def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Remove normalization layers from state_dict."""
    new_state_dict = {}

    # Patterns to remove
    remove_patterns = [
        "normalize_inputs.",
        "unnormalize_outputs.",
        "normalize_targets.",  # Added pattern for target normalization
        "normalize.",
        "unnormalize.",
        "input_normalizer.",
        "output_normalizer.",
        "normalizer.",
    ]

    for key, tensor in state_dict.items():
        should_remove = any(pattern in key for pattern in remove_patterns)
        if not should_remove:
            new_state_dict[key] = tensor

    return new_state_dict


def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
    """Convert features from the old format to PolicyFeature objects."""
    converted_features = {}

    for key, feature_dict in features_dict.items():
        # Determine feature type based on key
        if "image" in key or "visual" in key:
            feature_type = FeatureType.VISUAL
        elif "state" in key:
            feature_type = FeatureType.STATE
        elif "action" in key:
            feature_type = FeatureType.ACTION
        else:
            feature_type = FeatureType.STATE

        # Get shape from feature dict
        shape = feature_dict.get("shape", feature_dict.get("dim"))
        shape = (shape,) if isinstance(shape, int) else tuple(shape)

        converted_features[key] = PolicyFeature(feature_type, shape)

    return converted_features
def load_model_from_hub(
    repo_id: str, revision: str | None = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
    """Load the model state_dict and configs from the hub."""
    # Download files
    safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
    train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)

    # Load state_dict
    state_dict = load_safetensors(safetensors_path)

    # Load configs
    with open(config_path) as f:
        config = json.load(f)

    with open(train_config_path) as f:
        train_config = json.load(f)

    return state_dict, config, train_config
def main():
    parser = argparse.ArgumentParser(
        description="Migrate policy models with normalization layers to the new pipeline system"
    )
    parser.add_argument(
        "--pretrained-path",
        type=str,
        required=True,
        help="Path to pretrained model (hub repo or local directory)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory for migrated model (default: same as pretrained-path)",
    )
    parser.add_argument("--push-to-hub", action="store_true", help="Push migrated model to hub")
    parser.add_argument(
        "--hub-repo-id",
        type=str,
        default=None,
        help="Hub repository ID for pushing (default: same as pretrained-path)",
    )
    parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
    parser.add_argument("--private", action="store_true", help="Make the hub repository private")

    args = parser.parse_args()

    # Load model and config
    print(f"Loading model from {args.pretrained_path}...")
    if os.path.isdir(args.pretrained_path):
        # Local directory
        state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
        with open(os.path.join(args.pretrained_path, "config.json")) as f:
            config = json.load(f)
        with open(os.path.join(args.pretrained_path, "train_config.json")) as f:
            train_config = json.load(f)
    else:
        # Hub repository
        state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)

    # Extract normalization statistics
    print("Extracting normalization statistics...")
    stats = extract_normalization_stats(state_dict)
    print(f"Found normalization statistics for: {list(stats.keys())}")

    # Detect input features and normalization modes
    print("Detecting features and normalization modes...")
    features, norm_map = detect_features_and_norm_modes(config, stats)
    print(f"Detected features: {list(features.keys())}")
    print(f"Normalization modes: {norm_map}")

    # Remove normalization layers from state_dict
    print("Removing normalization layers from model...")
    new_state_dict = remove_normalization_layers(state_dict)
    removed_keys = set(state_dict.keys()) - set(new_state_dict.keys())
    if removed_keys:
        print(f"Removed {len(removed_keys)} normalization layer keys")

    # Determine output path
    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        if os.path.isdir(args.pretrained_path):
            output_dir = Path(args.pretrained_path).parent / f"{Path(args.pretrained_path).name}_migrated"
        else:
            output_dir = Path(f"./{args.pretrained_path.replace('/', '_')}_migrated")

    output_dir.mkdir(parents=True, exist_ok=True)

    # Clean up config - remove normalization_mapping field
    cleaned_config = dict(config)
    if "normalization_mapping" in cleaned_config:
        print("Removing 'normalization_mapping' field from config")
        del cleaned_config["normalization_mapping"]
    policy_type = deepcopy(cleaned_config["type"])

    del cleaned_config["type"]

    # Instantiate the policy model with cleaned config and load the cleaned state dict
    print(f"Instantiating {policy_type} policy model...")
    policy_class_path = POLICY_CLASSES[policy_type]
    module_path, class_name = policy_class_path.rsplit(".", 1)

    module = importlib.import_module(module_path)
    policy_class = getattr(module, class_name)

    # Create config class instance
    config_module_path = module_path.replace("modeling", "configuration")
    config_module = importlib.import_module(config_module_path)
    # Handle special cases for config class names
    config_class_names = {
        "act": "ACTConfig",
        "diffusion": "DiffusionConfig",
        "pi0": "PI0Config",
        "pi0fast": "PI0FASTConfig",
        "smolvla": "SmolVLAConfig",
        "tdmpc": "TDMPCConfig",
        "vqbet": "VQBeTConfig",
        "sac": "SACConfig",
        "classifier": "ClassifierConfig",
    }
    config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config")
    config_class = getattr(config_module, config_class_name)

    # Convert input_features and output_features to PolicyFeature objects - these are mandatory
    if "input_features" not in cleaned_config:
        raise ValueError("Missing mandatory 'input_features' in config")
    if "output_features" not in cleaned_config:
        raise ValueError("Missing mandatory 'output_features' in config")

    cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"])
    cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"])

    # Create config instance from cleaned config dict
    policy_config = config_class(**cleaned_config)

    # Create policy instance - some policies expect dataset_stats
    policy = policy_class(policy_config)

    # Load the cleaned state dict
    policy.load_state_dict(new_state_dict, strict=True)
    print("Successfully loaded cleaned state dict into policy model")

    # Now create preprocessor and postprocessor with cleaned_config available
    print("Creating preprocessor and postprocessor...")
    # The pattern from existing processor factories:
    # - The preprocessor has one NormalizerProcessorStep covering both input_features and output_features
    # - The postprocessor has one UnnormalizerProcessorStep for output_features only

    # Get features from cleaned_config (now they're PolicyFeature objects)
    input_features = cleaned_config.get("input_features", {})
    output_features = cleaned_config.get("output_features", {})

    # Create preprocessor (following the pattern from processor factories)
    preprocessor_steps = [
        RenameProcessorStep(rename_map={}),
        NormalizerProcessorStep(
            features={**input_features, **output_features},
            norm_map=norm_map,
            stats=stats,
        ),
        AddBatchDimensionProcessorStep(),
        DeviceProcessorStep(device=policy_config.device),
    ]
    preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps, name="robot_preprocessor")

    # Create postprocessor with unnormalizer for outputs only
    postprocessor_steps = [
        DeviceProcessorStep(device="cpu"),
        UnnormalizerProcessorStep(features=output_features, norm_map=norm_map, stats=stats),
    ]
    postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps, name="robot_postprocessor")

    # Determine hub repo ID if pushing to hub
    if args.push_to_hub:
        if args.hub_repo_id:
            hub_repo_id = args.hub_repo_id
        else:
            if not os.path.isdir(args.pretrained_path):
                # Use the same repo with a "_migrated" suffix
                hub_repo_id = f"{args.pretrained_path}_migrated"
            else:
                raise ValueError("--hub-repo-id must be specified when pushing a local model to the hub")
    else:
        hub_repo_id = None

    # Save preprocessor and postprocessor to the root directory
    print(f"Saving preprocessor to {output_dir}...")
    preprocessor.save_pretrained(output_dir)
    if args.push_to_hub:
        preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)

    print(f"Saving postprocessor to {output_dir}...")
    postprocessor.save_pretrained(output_dir)
    if args.push_to_hub:
        postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)

    # Save model using the policy's save_pretrained method
    print(f"Saving model to {output_dir}...")
    policy.save_pretrained(
        output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
    )

    # Generate and save model card
    print("Generating model card...")
    # Get metadata from the original config
    dataset_repo_id = train_config.get("repo_id", "unknown")
    license = config.get("license", "apache-2.0")

    tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
    tags = set(tags).union({"robotics", "lerobot", policy_type})
    tags = list(tags)

    # Generate model card
    card = policy.generate_model_card(
        dataset_repo_id=dataset_repo_id, model_type=policy_type, license=license, tags=tags
    )

    # Save model card locally
    card.save(str(output_dir / "README.md"))
    print(f"Model card saved to {output_dir / 'README.md'}")
    # Push model card to hub if requested
    if args.push_to_hub:
        from huggingface_hub import HfApi

        api = HfApi()
        api.upload_file(
            path_or_fileobj=str(output_dir / "README.md"),
            path_in_repo="README.md",
            repo_id=hub_repo_id,
            repo_type="model",
            commit_message="Add model card for migrated model",
        )
        print("Model card pushed to hub")

    print("\nMigration complete!")
    print(f"Migrated model saved to: {output_dir}")
    if args.push_to_hub:
        print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")


if __name__ == "__main__":
    main()
@@ -1,84 +1,179 @@
from __future__ import annotations

from copy import deepcopy
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any

import numpy as np
import torch
from torch import Tensor

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey

from .converters import to_tensor
from .core import EnvTransition, TransitionKey
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
    """Convert numpy arrays and other types to torch tensors."""
    tensor_stats: dict[str, dict[str, Tensor]] = {}
    for key, sub in stats.items():
        tensor_stats[key] = {}
        for stat_name, value in sub.items():
            if isinstance(value, np.ndarray):
                tensor_val = torch.from_numpy(value.astype(np.float32))
            elif isinstance(value, torch.Tensor):
                tensor_val = value.to(dtype=torch.float32)
            elif isinstance(value, (int, float, list, tuple)):
                tensor_val = torch.tensor(value, dtype=torch.float32)
            else:
                raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}")
            tensor_stats[key][stat_name] = tensor_val
    return tensor_stats
@dataclass
class _NormalizationMixin:
    """
    A mixin class providing core functionality for normalization and unnormalization.
@ProcessorStepRegistry.register(name="normalizer_processor")
class NormalizerProcessor:
    """Normalizes observations and actions in a single processor step.

    This class manages normalization statistics, their conversion to tensors, device placement,
    and the application of normalization transformations. It is designed to be inherited by
    concrete ProcessorStep implementations.
    This processor handles normalization of both observation and action tensors
    using either mean/std normalization or min/max scaling to a [-1, 1] range.

    For each tensor key in the stats dictionary, the processor will:
    - Use mean/std normalization if those statistics are provided: (x - mean) / std
    - Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1

    The processor can be configured to normalize only specific keys by setting
    the normalize_keys parameter.
    """

    # Features and normalisation map are mandatory to match the design of normalize.py
    features: dict[str, PolicyFeature]
    norm_map: dict[FeatureType, NormalizationMode]

    # Pre-computed statistics coming from dataset.meta.stats for instance.
    stats: dict[str, dict[str, Any]] | None = None
    device: torch.device | str | None = None

    # Explicit subset of keys to normalise. If ``None`` every key (except
    # "action") found in ``stats`` will be normalised. Using a ``set`` makes
    # membership checks O(1).
    normalize_keys: set[str] | None = None

    eps: float = 1e-8
    normalize_observation_keys: set[str] | None = None

    _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)

    @classmethod
    def from_lerobot_dataset(
        cls,
        dataset: LeRobotDataset,
        features: dict[str, PolicyFeature],
        norm_map: dict[FeatureType, NormalizationMode],
        *,
        normalize_keys: set[str] | None = None,
        eps: float = 1e-8,
    ) -> NormalizerProcessor:
        """Factory helper that pulls statistics from a :class:`LeRobotDataset`.

        The features and norm_map parameters are mandatory to match the design
        pattern used in normalize.py.
        """
        return cls(
            features=features,
            norm_map=norm_map,
            stats=dataset.meta.stats,
            normalize_keys=normalize_keys,
            eps=eps,
        )

    def __post_init__(self):
        # Robust JSON deserialization handling (guard empty maps)
        if self.features:
            first_val = next(iter(self.features.values()))
            if isinstance(first_val, dict):
                reconstructed = {}
                for key, ft_dict in self.features.items():
                    reconstructed[key] = PolicyFeature(
                        type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
                    )
                self.features = reconstructed
        # Handle deserialization from JSON config
        if self.features and isinstance(list(self.features.values())[0], dict):
            # Features came from JSON - need to reconstruct PolicyFeature objects
            reconstructed_features = {}
            for key, ft_dict in self.features.items():
                reconstructed_features[key] = PolicyFeature(
                    type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
                )
            self.features = reconstructed_features

        if self.norm_map:
            # if keys are strings (JSON), rebuild enum map
            if all(isinstance(k, str) for k in self.norm_map.keys()):
                reconstructed = {}
                for ft_type_str, norm_mode_str in self.norm_map.items():
                    reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
                self.norm_map = reconstructed
        if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
            # norm_map came from JSON - need to reconstruct enum keys and values
            reconstructed_norm_map = {}
            for ft_type_str, norm_mode_str in self.norm_map.items():
                reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
            self.norm_map = reconstructed_norm_map

        # Convert stats to tensors and move to the target device once during initialization.
        # Convert statistics once so we avoid repeated numpy→Tensor conversions
        # during runtime.
        self.stats = self.stats or {}
        self._tensor_stats = to_tensor(self.stats, device=self.device)
        self._tensor_stats = _convert_stats_to_tensors(self.stats)

    def to(self, device: torch.device | str) -> _NormalizationMixin:
        """Moves the processor's normalization stats to the specified device and returns self."""
        self.device = device
        self._tensor_stats = to_tensor(self.stats, device=self.device)
        return self
        # Ensure *normalize_keys* is a set for fast look-ups and compare by
        # value later when returning the configuration.
        if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
            self.normalize_keys = set(self.normalize_keys)

    def state_dict(self) -> dict[str, Tensor]:
        flat: dict[str, Tensor] = {}
        for key, sub in self._tensor_stats.items():
            for stat_name, tensor in sub.items():
                flat[f"{key}.{stat_name}"] = tensor.cpu()  # Always save to CPU
        return flat
    def _normalize_obs(self, observation):
        if observation is None:
            return None

    def load_state_dict(self, state: dict[str, Tensor]) -> None:
        self._tensor_stats.clear()
        for flat_key, tensor in state.items():
            key, stat_name = flat_key.rsplit(".", 1)
            # Load to the processor's configured device.
            self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
                dtype=torch.float32, device=self.device
            )
        # Decide which keys should be normalised for this call.
        if self.normalize_keys is not None:
            keys_to_norm = self.normalize_keys
        else:
            # Use feature map to skip action keys.
            keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION}

        processed = dict(observation)
        for key in keys_to_norm:
            if key not in processed or key not in self._tensor_stats:
                continue

            orig_val = processed[key]
            tensor = (
                orig_val.to(dtype=torch.float32)
                if isinstance(orig_val, torch.Tensor)
                else torch.as_tensor(orig_val, dtype=torch.float32)
            )
            stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}

            if "mean" in stats and "std" in stats:
                mean, std = stats["mean"], stats["std"]
                processed[key] = (tensor - mean) / (std + self.eps)
            elif "min" in stats and "max" in stats:
                min_val, max_val = stats["min"], stats["max"]
                processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
        return processed

    def _normalize_action(self, action):
        if action is None or "action" not in self._tensor_stats:
            return action

        tensor = (
            action.to(dtype=torch.float32)
            if isinstance(action, torch.Tensor)
            else torch.as_tensor(action, dtype=torch.float32)
        )
        stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
        if "mean" in stats and "std" in stats:
            mean, std = stats["mean"], stats["std"]
            return (tensor - mean) / (std + self.eps)
        if "min" in stats and "max" in stats:
            min_val, max_val = stats["min"], stats["max"]
            return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
        raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
        action = self._normalize_action(transition.get(TransitionKey.ACTION))

        # Create a new transition with normalized values
        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = observation
        new_transition[TransitionKey.ACTION] = action
        return new_transition

    def get_config(self) -> dict[str, Any]:
        config = {
@@ -88,185 +183,149 @@ class _NormalizationMixin:
            },
            "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
        }
        if self.normalize_observation_keys is not None:
            config["normalize_observation_keys"] = sorted(self.normalize_observation_keys)
        if self.normalize_keys is not None:
            # Serialise as a list for YAML / JSON friendliness
            config["normalize_keys"] = sorted(self.normalize_keys)
        return config

    def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]:
        new_observation = dict(observation)
        for key, feature in self.features.items():
            if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys:
                continue
            if feature.type != FeatureType.ACTION and key in new_observation:
                tensor = torch.as_tensor(new_observation[key], dtype=torch.float32)
                new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse)
        return new_observation
    def state_dict(self) -> dict[str, Tensor]:
        flat = {}
        for key, sub in self._tensor_stats.items():
            for stat_name, tensor in sub.items():
                flat[f"{key}.{stat_name}"] = tensor
        return flat

    def _normalize_action(self, action: Any, inverse: bool) -> Tensor:
        tensor = torch.as_tensor(action, dtype=torch.float32)
        processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse)
        return processed_action
    def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
        self._tensor_stats.clear()
        for flat_key, tensor in state.items():
            key, stat_name = flat_key.rsplit(".", 1)
            self._tensor_stats.setdefault(key, {})[stat_name] = tensor

    def _apply_transform(
        self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False
    ) -> Tensor:
        """Core logic to apply normalization or unnormalization."""
        norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
        if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
            return tensor
    def reset(self):
        pass

        if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX):
            raise ValueError(f"Unsupported normalization mode: {norm_mode}")

        # Ensure input tensor is on the same device as the stats.
        if self.device and tensor.device != self.device:
            tensor = tensor.to(self.device)

        # For Accelerate compatibility: move stats to match input tensor device
        input_device = tensor.device
        stats = self._tensor_stats[key]
        tensor = tensor.to(dtype=torch.float32)

        # Move stats to input device if needed
        stats_device = next(iter(stats.values())).device
        if stats_device != input_device:
            stats = to_tensor({key: self._tensor_stats[key]}, device=input_device)[key]

        if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats:
            mean, std = stats["mean"], stats["std"]
            # Avoid division by zero by adding a small epsilon.
            denom = std + self.eps
            if inverse:
                return tensor * std + mean
            return (tensor - mean) / denom

        if norm_mode == NormalizationMode.MIN_MAX and "min" in stats and "max" in stats:
            min_val, max_val = stats["min"], stats["max"]
            denom = max_val - min_val
            # When min_val == max_val, substitute the denominator with a small epsilon
            # to prevent division by zero. This consistently maps an input equal to
            # min_val to -1, ensuring a stable transformation.
            denom = torch.where(
                denom == 0, torch.tensor(self.eps, device=input_device, dtype=torch.float32), denom
            )
            if inverse:
                # Map from [-1, 1] back to [min, max]
                return (tensor + 1) / 2 * denom + min_val
            # Map from [min, max] to [-1, 1]
            return 2 * (tensor - min_val) / denom - 1

        # If necessary stats are missing, return input unchanged.
        return tensor
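A worked numeric check of the two transforms in `_apply_transform` (epsilon terms omitted for clarity; values are made up):

```python
import torch

x = torch.tensor([3.0])

# MEAN_STD: (x - mean) / std, inverse x * std + mean
mean, std = torch.tensor([1.0]), torch.tensor([2.0])
normed = (x - mean) / std                      # tensor([1.])
assert torch.allclose(normed * std + mean, x)  # inverse recovers x

# MIN_MAX: 2 * (x - min) / (max - min) - 1, inverse (x + 1) / 2 * (max - min) + min
min_val, max_val = torch.tensor([0.0]), torch.tensor([4.0])
scaled = 2 * (x - min_val) / (max_val - min_val) - 1   # tensor([0.5000])
assert torch.allclose((scaled + 1) / 2 * (max_val - min_val) + min_val, x)
```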
@dataclass
@ProcessorStepRegistry.register(name="normalizer_processor")
class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
    """
    A processor that applies normalization to observations and actions in a transition.

    This class directly implements the normalization logic for both observation and action
    components of an `EnvTransition`, using statistics (mean/std or min/max) provided at
    initialization.
    """

    @classmethod
    def from_lerobot_dataset(
        cls,
        dataset: LeRobotDataset,
        features: dict[str, PolicyFeature],
        norm_map: dict[FeatureType, NormalizationMode],
        *,
        normalize_observation_keys: set[str] | None = None,
        eps: float = 1e-8,
        device: torch.device | str | None = None,
    ) -> NormalizerProcessorStep:
        return cls(
            features=features,
            norm_map=norm_map,
            stats=dataset.meta.stats,
            normalize_observation_keys=normalize_observation_keys,
            eps=eps,
            device=device,
        )

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        new_transition = transition.copy()

        # Handle observation normalization.
        observation = new_transition.get(TransitionKey.OBSERVATION)
        if observation is not None:
            new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(
                observation, inverse=False
            )

        # Handle action normalization.
        action = new_transition.get(TransitionKey.ACTION)
        if action is not None:
            new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False)

        return new_transition

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@dataclass
@ProcessorStepRegistry.register(name="unnormalizer_processor")
class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
    """
    A processor that applies unnormalization (the inverse of normalization) to
    observations and actions in a transition.
class UnnormalizerProcessor:
    """Inverse normalisation for observations and actions.

    This is typically used to transform actions from a normalized policy output back into
    the original scale for execution in an environment.
    Exactly mirrors :class:`NormalizerProcessor` but applies the inverse
    transform.
    """

    features: dict[str, PolicyFeature]
    norm_map: dict[FeatureType, NormalizationMode]
    stats: dict[str, dict[str, Any]] | None = None

    _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)

    @classmethod
    def from_lerobot_dataset(
        cls,
        dataset: LeRobotDataset,
        features: dict[str, PolicyFeature],
        norm_map: dict[FeatureType, NormalizationMode],
        *,
        device: torch.device | str | None = None,
    ) -> UnnormalizerProcessorStep:
        return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device)
    ) -> UnnormalizerProcessor:
        return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats)

    def __post_init__(self):
        # Handle deserialization from JSON config
        if self.features and isinstance(list(self.features.values())[0], dict):
            # Features came from JSON - need to reconstruct PolicyFeature objects
            reconstructed_features = {}
            for key, ft_dict in self.features.items():
                reconstructed_features[key] = PolicyFeature(
                    type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
                )
            self.features = reconstructed_features

        if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
            # norm_map came from JSON - need to reconstruct enum keys and values
            reconstructed_norm_map = {}
            for ft_type_str, norm_mode_str in self.norm_map.items():
                reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
            self.norm_map = reconstructed_norm_map

        self.stats = self.stats or {}
        self._tensor_stats = _convert_stats_to_tensors(self.stats)

    def _unnormalize_obs(self, observation):
        if observation is None:
            return None
        keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
        processed = dict(observation)
        for key in keys:
            if key not in processed or key not in self._tensor_stats:
                continue
            orig_val = processed[key]
            tensor = (
                orig_val.to(dtype=torch.float32)
                if isinstance(orig_val, torch.Tensor)
                else torch.as_tensor(orig_val, dtype=torch.float32)
            )
            stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
            if "mean" in stats and "std" in stats:
                mean, std = stats["mean"], stats["std"]
                processed[key] = tensor * std + mean
            elif "min" in stats and "max" in stats:
                min_val, max_val = stats["min"], stats["max"]
                processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
        return processed

    def _unnormalize_action(self, action):
        if action is None or "action" not in self._tensor_stats:
            return action
        tensor = (
            action.to(dtype=torch.float32)
            if isinstance(action, torch.Tensor)
            else torch.as_tensor(action, dtype=torch.float32)
        )
        stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
        if "mean" in stats and "std" in stats:
            mean, std = stats["mean"], stats["std"]
            return tensor * std + mean
        if "min" in stats and "max" in stats:
            min_val, max_val = stats["min"], stats["max"]
            return (tensor + 1) / 2 * (max_val - min_val) + min_val
        raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
        action = self._unnormalize_action(transition.get(TransitionKey.ACTION))

        # Create a new transition with unnormalized values
        new_transition = transition.copy()

        # Handle observation unnormalization.
        observation = new_transition.get(TransitionKey.OBSERVATION)
        if observation is not None:
            new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(observation, inverse=True)

        # Handle action unnormalization.
        action = new_transition.get(TransitionKey.ACTION)
        if action is not None:
            new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True)

        new_transition[TransitionKey.OBSERVATION] = observation
        new_transition[TransitionKey.ACTION] = action
        return new_transition

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    def get_config(self) -> dict[str, Any]:
        return {
            "features": {
                key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
            },
            "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
        }

    def state_dict(self) -> dict[str, Tensor]:
        flat = {}
        for key, sub in self._tensor_stats.items():
            for stat_name, tensor in sub.items():
                flat[f"{key}.{stat_name}"] = tensor
        return flat

    def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
        self._tensor_stats.clear()
        for flat_key, tensor in state.items():
            key, stat_name = flat_key.rsplit(".", 1)
            self._tensor_stats.setdefault(key, {})[stat_name] = tensor

    def reset(self):
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features


def hotswap_stats(
    policy_processor: PolicyProcessorPipeline, stats: dict[str, dict[str, Any]]
) -> PolicyProcessorPipeline:
    """
    Replaces normalization statistics in a PolicyProcessor pipeline.

    This function creates a deep copy of the provided `PolicyProcessorPipeline` and updates the
    statistics of any `NormalizerProcessorStep` or `UnnormalizerProcessorStep` steps within it.
    It's useful for adapting a trained policy to a new environment or dataset with
    different data distributions.
|
||||
"""
|
||||
rp = deepcopy(policy_processor)
|
||||
for step in rp.steps:
|
||||
if isinstance(step, _NormalizationMixin):
|
||||
step.stats = stats
|
||||
# Re-initialize tensor_stats on the correct device.
|
||||
step._tensor_stats = to_tensor(stats, device=step.device)
|
||||
return rp
|
||||
|
||||
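A quick numeric check of the MIN_MAX inverse used by `_unnormalize_obs` and `_unnormalize_action` above; the stats values here are made up for illustration:

```python
import torch

# Hypothetical per-feature stats, shaped like dataset.meta.stats entries.
stats = {"min": torch.tensor([0.0, -1.0]), "max": torch.tensor([10.0, 1.0])}

normalized = torch.tensor([1.0, -1.0])  # policy output in [-1, 1]
# Same formula as above: map [-1, 1] back to [min, max].
raw = (normalized + 1) / 2 * (stats["max"] - stats["min"]) + stats["min"]
print(raw)  # tensor([10., -1.])
```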
@@ -22,13 +22,12 @@ from torch import Tensor

from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE

from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry


@dataclass
@ProcessorStepRegistry.register(name="observation_processor")
class VanillaObservationProcessorStep(ObservationProcessorStep):
class VanillaObservationProcessor(ObservationProcessor):
    """
    Processes environment observations into the LeRobot format by handling both images and states.

@@ -107,8 +106,9 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
    def observation(self, observation):
        return self._process_observation(observation)

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Transforms feature keys to a standardized contract.

        This method handles several renaming patterns:
        - Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
        - Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').

+425 -279 (file diff suppressed because it is too large)
@@ -13,18 +13,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any

from lerobot.configs.types import PolicyFeature

from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
from lerobot.processor.pipeline import (
    ObservationProcessor,
    ProcessorStepRegistry,
)


@dataclass
@ProcessorStepRegistry.register(name="rename_processor")
class RenameProcessorStep(ObservationProcessorStep):
class RenameProcessor(ObservationProcessor):
    """Rename processor that renames keys in the observation."""

    rename_map: dict[str, str] = field(default_factory=dict)
@@ -42,20 +43,9 @@ class RenameProcessorStep(ObservationProcessorStep):
    def get_config(self) -> dict[str, Any]:
        return {"rename_map": self.rename_map}

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Transforms:
        - Each key in the observation that appears in `rename_map` is renamed to its value.
        - Keys not in `rename_map` remain unchanged.
        """
        return {self.rename_map.get(k, k): v for k, v in features.items()}


def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
    """Rename keys in the stats dictionary according to rename_map (defensive copy)."""
    if not stats:
        return {}
    renamed: dict[str, dict[str, Any]] = {}
    for old_key, sub_stats in stats.items():
        new_key = rename_map.get(old_key, old_key)
        renamed[new_key] = deepcopy(sub_stats) if sub_stats is not None else {}
    return renamed
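The key-renaming in `feature_contract` and `rename_stats` is a plain dict lookup with passthrough for unknown keys; a minimal sketch (key names illustrative):

```python
rename_map = {"pixels": "observation.images.cam_high"}

features = {"pixels": "image_feature", "state": "state_feature"}
renamed = {rename_map.get(k, k): v for k, v in features.items()}
print(renamed)  # {'observation.images.cam_high': 'image_feature', 'state': 'state_feature'}
```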
@@ -1,246 +0,0 @@
"""
Tokenizer processor for handling text tokenization in robot transitions.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import torch

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
from lerobot.utils.import_utils import _transformers_available

from .core import EnvTransition, TransitionKey
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry

if TYPE_CHECKING or _transformers_available:
    from transformers import AutoTokenizer
else:
    AutoTokenizer = None


@dataclass
@ProcessorStepRegistry.register(name="tokenizer_processor")
class TokenizerProcessorStep(ObservationProcessorStep):
    """Tokenizes text tasks in complementary data using a huggingface tokenizer.

    This processor handles tokenization of task strings found in the complementary_data
    using a specified pretrained tokenizer from Hugging Face. It adds tokenized versions
    to the observation data for model processing while preserving the original task string.

    The processor supports both single strings and lists of strings as task inputs.

    Args:
        tokenizer_name: Name of the pretrained tokenizer to load from Hugging Face Hub
            (e.g., "bert-base-uncased", "microsoft/DialoGPT-medium"). This will be used
            with AutoTokenizer.from_pretrained(). If tokenizer is provided, this is ignored.
        tokenizer: A tokenizer object (e.g., from transformers library) that implements
            the __call__ method. If provided, tokenizer_name is ignored. This parameter
            is not serialized and must be provided via overrides when loading.
        max_length: Maximum sequence length for tokenization. Defaults to 512.
        task_key: Key in complementary_data containing the task text. Defaults to "task".
        padding: Padding strategy for tokenization. Defaults to "max_length".
        truncation: Whether to truncate sequences longer than max_length. Defaults to True.

    Examples:
        Using tokenizer name (auto-loaded):
        ```python
        processor = TokenizerProcessorStep(tokenizer_name="bert-base-uncased", max_length=128)
        ```

        Using custom tokenizer object:
        ```python
        from transformers import AutoTokenizer

        custom_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        processor = TokenizerProcessorStep(tokenizer=custom_tokenizer, max_length=128)
        ```
    """

    tokenizer_name: str | None = None
    tokenizer: Any | None = None  # Otherwise transformers is not available in the core dependencies
    max_length: int = 512
    task_key: str = "task"
    padding_side: str = "right"
    padding: str = "max_length"
    truncation: bool = True

    # Internal tokenizer instance (not serialized)
    input_tokenizer: Any = field(default=None, init=False, repr=False)

    def __post_init__(self):
        """Initialize the tokenizer from the provided tokenizer or tokenizer name."""
        if not _transformers_available:
            raise ImportError(
                "The 'transformers' library is not installed. "
                "Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessorStep."
            )

        if self.tokenizer is not None:
            # Use provided tokenizer object directly
            self.input_tokenizer = self.tokenizer
        elif self.tokenizer_name is not None:
            if AutoTokenizer is None:
                raise ImportError("AutoTokenizer is not available")
            self.input_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        else:
            raise ValueError(
                "Either 'tokenizer' or 'tokenizer_name' must be provided. "
                "Pass a tokenizer object directly or a tokenizer name to auto-load."
            )

    def get_task(self, transition: EnvTransition) -> list[str] | None:
        """Extract and normalize task from complementary data.

        Args:
            transition: Input transition containing complementary_data.

        Returns:
            List of task strings if task is present, None otherwise.
        """
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
        if complementary_data is None:
            return None

        if self.task_key not in complementary_data:
            return None

        task = complementary_data[self.task_key]
        if task is None:
            return None

        # Convert to list of strings
        if isinstance(task, str):
            return [task]
        elif isinstance(task, list) and all(isinstance(t, str) for t in task):
            return task

        return None

    def observation(self, observation):
        """Process the transition by tokenizing the task text.

        Args:
            transition: Input transition containing complementary_data with task text.

        Returns:
            Modified transition with tokenized task added to observation.

        Raises:
            ValueError: If tokenizer initialization failed.
        """
        task = self.get_task(self.transition)
        if task is None:
            return observation

        # Tokenize the task (creates CPU tensors)
        tokenized_prompt = self._tokenize_text(task)

        # Detect device from existing tensors in the transition
        target_device = self._detect_device(self.transition)

        # Move tokenized tensors to match the device of other data
        if target_device is not None:
            tokenized_prompt = {
                k: v.to(target_device) if isinstance(v, torch.Tensor) else v
                for k, v in tokenized_prompt.items()
            }

        # Get or create observation dict
        new_observation = dict(observation)

        # Add tokenized data to observation
        new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
        new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)

        return new_observation

    def _detect_device(self, transition: EnvTransition) -> torch.device | None:
        """Detect device from existing tensors in the transition.

        This allows the tokenized tensors to match the device of other data,
        which is especially important for multi-GPU training with Accelerate.

        Args:
            transition: The transition to search for existing tensors.

        Returns:
            The device of the first tensor found, or None if no tensors exist.
        """
        # Check observation tensors first (most likely to exist)
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation:
            for value in observation.values():
                if isinstance(value, torch.Tensor):
                    return value.device

        # Check action tensor
        action = transition.get(TransitionKey.ACTION)
        if isinstance(action, torch.Tensor):
            return action.device

        return None  # No tensors found, keep on CPU

    def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
        """Tokenize text using the configured tokenizer.

        Args:
            text: Text string or list of strings to tokenize.

        Returns:
            Dictionary containing tokenized output with keys like 'input_ids', 'attention_mask'.
        """
        return self.input_tokenizer(
            text,
            max_length=self.max_length,
            truncation=self.truncation,
            padding=self.padding,
            padding_side=self.padding_side,
            return_tensors="pt",
        )

    def get_config(self) -> dict[str, Any]:
        """Return configuration for serialization.

        Note: Only tokenizer_name is saved, not the tokenizer object itself.
        When loading, provide the tokenizer via overrides if needed.
        """
        config = {
            "max_length": self.max_length,
            "task_key": self.task_key,
            "padding_side": self.padding_side,
            "padding": self.padding,
            "truncation": self.truncation,
        }

        # Only include tokenizer_name if it was used (not when tokenizer object was provided)
        # TODO(steven): Consider saving the name of the _tokenizer if it was loaded
        if self.tokenizer_name is not None and self.tokenizer is None:
            config["tokenizer_name"] = self.tokenizer_name

        return config

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Add tokenized task features to the feature contract.

        Args:
            features: Input feature dictionary.

        Returns:
            Updated feature dictionary with tokenized task features added.
        """
        # Add features for tokenized output if they don't exist
        # Standard tokenizer output includes tokens and attention_mask

        if OBS_LANGUAGE_TOKENS not in features:
            features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))

        if OBS_LANGUAGE_ATTENTION_MASK not in features:
            features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
                type=FeatureType.LANGUAGE, shape=(self.max_length,)
            )

        return features
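For reference, here is roughly what `_tokenize_text` produces with a pad-to-max-length setup (assumes `transformers` is installed and the tokenizer can be downloaded):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
out = tokenizer(
    ["pick up the cube"],
    max_length=16,
    truncation=True,
    padding="max_length",
    return_tensors="pt",
)
print(out["input_ids"].shape)       # torch.Size([1, 16])
print(out["attention_mask"].shape)  # torch.Size([1, 16]); the step casts this to bool
```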
@@ -59,7 +59,7 @@ lerobot-record \

import logging
import time
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat

@@ -72,23 +72,10 @@ from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import (
    IdentityProcessorStep,
    PolicyProcessorPipeline,
    RobotProcessorPipeline,
    TransitionKey,
)
from lerobot.processor.converters import (
    action_to_transition,
    observation_to_transition,
    transition_to_dataset_frame,
    transition_to_robot_action,
)
from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import (  # noqa: F401
    Robot,
    RobotConfig,
@@ -162,8 +149,6 @@ class DatasetRecordConfig:
    # Number of episodes to record before batch encoding videos
    # Set to 1 for immediate encoding (default behavior), or higher for batched encoding
    video_encoding_batch_size: int = 1
    # Rename map for the observation to override the image and state keys
    rename_map: dict[str, str] = field(default_factory=dict)

    def __post_init__(self):
        if self.single_task is None:
@@ -202,36 +187,6 @@ class RecordConfig:
        return ["policy"]


""" --------------- record_loop() data flow --------------------------
                         [ Robot ]
                             V
              [ robot.get_observation() ] ---> raw_obs
                             V
           [ robot_observation_processor ] ---> obs_transition
                             V
         .-----( ACTION LOGIC )------------------.
         V                                       V
 [ From Teleoperator ]                   [ From Policy ]
         |                                       |
 | [teleop.get_action] -> raw_action     | [predict_action]
 |          |                            |       |
 |          V                            |       V
 | [teleop_action_processor]             |       |
 |          |                            |       |
 '---> teleop_transition                 '---> policy_transition
            |                                    |
            '-------------------.----------------'
                                V
                 [ robot_action_processor ] --> robot_action_to_send
                                V
                    [ robot.send_action() ] -- (Robot Executes)
                                V
            ( Transitions are merged & added to Dataset )
                                V
                      ( Rerun Log / Loop Wait )
"""


@safe_stop_image_writer
def record_loop(
    robot: Robot,
@@ -240,38 +195,28 @@ def record_loop(
    dataset: LeRobotDataset | None = None,
    teleop: Teleoperator | list[Teleoperator] | None = None,
    policy: PreTrainedPolicy | None = None,
    preprocessor: PolicyProcessorPipeline | None = None,
    postprocessor: PolicyProcessorPipeline | None = None,
    control_time_s: int | None = None,
    teleop_action_processor: RobotProcessorPipeline | None = None,  # runs after teleop
    robot_action_processor: RobotProcessorPipeline | None = None,  # runs before robot
    robot_observation_processor: RobotProcessorPipeline | None = None,  # runs after robot
    single_task: str | None = None,
    display_data: bool = False,
):
    teleop_action_processor = teleop_action_processor or RobotProcessorPipeline(
        steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=lambda tr: tr
    )
    robot_action_processor = robot_action_processor or RobotProcessorPipeline(
        steps=[IdentityProcessorStep()], to_transition=lambda tr: tr, to_output=transition_to_robot_action
    )
    robot_observation_processor = robot_observation_processor or RobotProcessorPipeline(
        steps=[IdentityProcessorStep()],
        to_transition=observation_to_transition,
        to_output=lambda tr: tr,
    )

    if dataset is not None and dataset.fps != fps:
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    teleop_arm = teleop_keyboard = None
    if isinstance(teleop, list):  # For LeKiwi
    if isinstance(teleop, list):
        teleop_keyboard = next((t for t in teleop if isinstance(t, KeyboardTeleop)), None)
        teleop_arm = next(
            (
                t
                for t in teleop
                if isinstance(t, (so100_leader.SO100Leader, so101_leader.SO101Leader, koch_leader.KochLeader))
                if isinstance(
                    t,
                    (
                        so100_leader.SO100Leader,
                        so101_leader.SO101Leader,
                        koch_leader.KochLeader,
                    ),
                )
            ),
            None,
        )
@@ -281,20 +226,9 @@ def record_loop(
            "For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
        )

    # Reset policy and processor if they are provided
    if policy is not None and preprocessor is not None and postprocessor is not None:
    # if policy is given it needs cleaning up
    if policy is not None:
        policy.reset()
        preprocessor.reset()
        postprocessor.reset()

    # Reset custom pipelines
    teleop_action_processor.reset()
    robot_action_processor.reset()
    robot_observation_processor.reset()

    policy_transition = None
    teleop_transition = None
    obs_transition = None

    timestamp = 0
    start_episode_t = time.perf_counter()
@@ -305,88 +239,51 @@ def record_loop(
            events["exit_early"] = False
            break

        # Get robot observation
        obs = robot.get_observation()
        observation = robot.get_observation()

        # Applies a pipeline to the raw robot observation, default is IdentityProcessor
        obs_transition = robot_observation_processor(obs)

        # Get action from either policy or teleop
        if policy is not None and preprocessor is not None and postprocessor is not None:
            if dataset is not None:
                observation_frame = transition_to_dataset_frame(
                    obs_transition, dataset.features
                )  # Convert the observation to the dataset format
        if policy is not None or dataset is not None:
            observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")

        if policy is not None:
            action_values = predict_action(
                observation=observation_frame,
                policy=policy,
                device=get_safe_torch_device(policy.config.device),
                preprocessor=preprocessor,
                postprocessor=postprocessor,
                use_amp=policy.config.use_amp,
                observation_frame,
                policy,
                get_safe_torch_device(policy.config.device),
                policy.config.use_amp,
                task=single_task,
                robot_type=robot.robot_type,
            )

            action_names = dataset.features["action"]["names"]
            policy_action = {f"action.{name}": float(action_values[i]) for i, name in enumerate(action_names)}
            policy_transition = {
                TransitionKey.ACTION: policy_action,
                TransitionKey.COMPLEMENTARY_DATA: {},
            }

        elif isinstance(teleop, Teleoperator):
            act = teleop.get_action()

            # Applies a pipeline to the raw teleop action, default is IdentityProcessor
            teleop_transition = teleop_action_processor(act)

        elif isinstance(teleop, list):
            action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
        elif policy is None and isinstance(teleop, Teleoperator):
            action = teleop.get_action()
        elif policy is None and isinstance(teleop, list):
            # TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline)
            arm_action = teleop_arm.get_action()
            arm_action = {f"arm_{k}": v for k, v in arm_action.items()}

            keyboard_action = teleop_keyboard.get_action()
            base_action = robot._from_keyboard_to_base_action(keyboard_action)
            act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
            teleop_transition = teleop_action_processor(act)

            action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
        else:
            logging.info(
                "No policy or teleoperator provided, skipping action generation. "
                "No policy or teleoperator provided, skipping action generation."
                "This is likely to happen when resetting the environment without a teleop device."
                "The robot won't be at its rest position at the start of the next episode."
            )
            continue

        # Applies a pipeline to the action, default is IdentityProcessor
        # IMPORTANT: action_pipeline.to_output must return a dict suitable for robot.send_action()
        if policy is not None and policy_transition is not None:
            robot_action_to_send = robot_action_processor(policy_transition)
        else:
            robot_action_to_send = robot_action_processor(teleop_transition)

        # Send action to robot
        # Action can eventually be clipped using `max_relative_target`,
        # so action actually sent is saved in the dataset. action = postprocessor.process(action)
        # TODO(pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
        _ = robot.send_action(robot_action_to_send)
        # so action actually sent is saved in the dataset.
        sent_action = robot.send_action(action)

        # Write to dataset
        if dataset is not None:
            # If transition_to_dataset_frame is provided, use it to merge the transitions.
            merged = []
            if obs_transition is not None:  # The observation from the robot
                merged.append(obs_transition)
            if teleop_transition is not None:  # The action from teleop
                merged.append(teleop_transition)
            if policy_transition is not None:  # The action from policy
                merged.append(policy_transition)
            frame = transition_to_dataset_frame(
                merged if len(merged) > 1 else merged[0], dataset.features
            )  # Convert the observation to the dataset format
            dataset.add_frame(frame, task=single_task)
            action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
            frame = {**observation_frame, **action_frame, "task": single_task}
            dataset.add_frame(frame)

        if display_data:
            log_rerun_data([obs_transition, teleop_transition or policy_transition])
            log_rerun_data(observation, action)

        dt_s = time.perf_counter() - start_loop_t
        busy_wait(1 / fps - dt_s)
@@ -406,15 +303,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:

    action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video)
    obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video)

    # Add next.* features that are generated during recording
    transition_features = {
        "next.reward": {"dtype": "float32", "shape": (1,), "names": None},
        "next.done": {"dtype": "bool", "shape": (1,), "names": None},
        "next.truncated": {"dtype": "bool", "shape": (1,), "names": None},
    }

    dataset_features = {**action_features, **obs_features, **transition_features}
    dataset_features = {**action_features, **obs_features}

    if cfg.resume:
        dataset = LeRobotDataset(
@@ -446,18 +335,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:

    # Load pretrained policy
    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
    preprocessor = None
    postprocessor = None
    if cfg.policy is not None:
        preprocessor, postprocessor = make_pre_post_processors(
            policy_cfg=cfg.policy,
            pretrained_path=cfg.policy.pretrained_path,
            dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
            preprocessor_overrides={
                "device_processor": {"device": cfg.policy.device},
                "rename_processor": {"rename_map": cfg.dataset.rename_map},
            },
        )

    robot.connect()
    if teleop is not None:
@@ -475,8 +352,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
        fps=cfg.dataset.fps,
        teleop=teleop,
        policy=policy,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        dataset=dataset,
        control_time_s=cfg.dataset.episode_time_s,
        single_task=cfg.dataset.single_task,
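The recording loop paces itself to the requested `fps` by measuring how long each iteration took and waiting out the remainder of the frame budget. The pattern, sketched with `time.sleep` standing in for lerobot's `busy_wait`:

```python
import time

fps = 30
start_loop_t = time.perf_counter()
# ... get observation, compute and send action, write frame ...
dt_s = time.perf_counter() - start_loop_t
time.sleep(max(0.0, 1 / fps - dt_s))  # busy_wait spins instead, for better timing precision
```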
@@ -45,10 +45,9 @@ from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat

from lerobot.configs import parser
import draccus

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor import IdentityProcessorStep, RobotProcessorPipeline
from lerobot.processor.converters import action_to_transition, transition_to_robot_action
from lerobot.robots import (  # noqa: F401
    Robot,
    RobotConfig,
@@ -56,6 +55,7 @@ from lerobot.robots import (  # noqa: F401
    hope_jr,
    koch_follower,
    make_robot_from_config,
    reachy2,
    so100_follower,
    so101_follower,
)
@@ -84,32 +84,24 @@ class ReplayConfig:
    dataset: DatasetReplayConfig
    # Use vocal synthesis to read events.
    play_sounds: bool = True
    # Optional processor for actions before sending to robot
    robot_action_processor: RobotProcessorPipeline | None = None


@parser.wrap()
@draccus.wrap()
def replay(cfg: ReplayConfig):
    init_logging()
    logging.info(pformat(asdict(cfg)))

    # Initialize robot action processor with default if not provided
    robot_action_processor = cfg.robot_action_processor or RobotProcessorPipeline(
        steps=[IdentityProcessorStep()],
        to_transition=action_to_transition,
        to_output=transition_to_robot_action,  # type: ignore[arg-type]
    )

    # Reset processor
    robot_action_processor.reset()

    robot = make_robot_from_config(cfg.robot)
    dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
    actions = dataset.hf_dataset.select_columns("action")

    # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
    episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode)
    actions = episode_frames.select_columns("action")

    robot.connect()

    log_say("Replaying episode", cfg.play_sounds, blocking=True)
    for idx in range(dataset.num_frames):
    for idx in range(len(episode_frames)):
        start_episode_t = time.perf_counter()

        action_array = actions[idx]["action"]
@@ -117,10 +109,7 @@ def replay(cfg: ReplayConfig):
        for i, name in enumerate(dataset.features["action"]["names"]):
            action[name] = action_array[i]

        # Process action through robot action processor
        processed_action = robot_action_processor(action)

        robot.send_action(processed_action)  # type: ignore[arg-type]
        robot.send_action(action)

        dt_s = time.perf_counter() - start_episode_t
        busy_wait(1 / dataset.fps - dt_s)
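Replay therefore just turns each stored action vector back into the named dict that `robot.send_action()` expects; schematically (motor names illustrative):

```python
action_array = [0.11, -0.42, 0.73]
names = ["shoulder_pan.pos", "shoulder_lift.pos", "elbow_flex.pos"]
action = {name: action_array[i] for i, name in enumerate(names)}
# -> {'shoulder_pan.pos': 0.11, 'shoulder_lift.pos': -0.42, 'elbow_flex.pos': 0.73}
```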
@@ -29,10 +29,10 @@ class BiSO100FollowerConfig(RobotConfig):

    # Optional
    left_arm_disable_torque_on_disconnect: bool = True
    left_arm_max_relative_target: int | None = None
    left_arm_max_relative_target: float | dict[str, float] | None = None
    left_arm_use_degrees: bool = False
    right_arm_disable_torque_on_disconnect: bool = True
    right_arm_max_relative_target: int | None = None
    right_arm_max_relative_target: float | dict[str, float] | None = None
    right_arm_use_degrees: bool = False

    # cameras (shared between both arms)

@@ -44,8 +44,8 @@ class HopeJrArmConfig(RobotConfig):
    disable_torque_on_disconnect: bool = True

    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None
    # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
    # names to the max_relative_target value for that motor.
    max_relative_target: float | dict[str, float] | None = None

    cameras: dict[str, CameraConfig] = field(default_factory=dict)

@@ -28,9 +28,9 @@ class KochFollowerConfig(RobotConfig):
    disable_torque_on_disconnect: bool = True

    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None
    # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
    # names to the max_relative_target value for that motor.
    max_relative_target: float | dict[str, float] | None = None

    # cameras
    cameras: dict[str, CameraConfig] = field(default_factory=dict)

@@ -110,6 +110,7 @@ class KochFollower(Robot):
        return self.bus.is_calibrated

    def calibrate(self) -> None:
        self.bus.disable_torque()
        if self.calibration:
            # Calibration file exists, ask user whether to use it or run new calibration
            user_input = input(
@@ -120,7 +121,6 @@ class KochFollower(Robot):
                self.bus.write_calibration(self.calibration)
                return
        logger.info(f"\nRunning calibration of {self}")
        self.bus.disable_torque()
        for motor in self.bus.motors:
            self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)

@@ -39,9 +39,9 @@ class LeKiwiConfig(RobotConfig):
    disable_torque_on_disconnect: bool = True

    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None
    # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
    # names to the max_relative_target value for that motor.
    max_relative_target: float | dict[str, float] | None = None

    cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
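Across these configs, `max_relative_target` is widened from a single `int` to `float | dict[str, float]`, so the safety clamp can differ per motor. Hypothetical values showing both accepted forms:

```python
# One limit for every motor:
max_relative_target = 5.0

# Or per-motor limits keyed by motor name (names illustrative):
max_relative_target = {"shoulder_pan": 5.0, "wrist_roll": 10.0, "gripper": 20.0}
```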
@@ -1,6 +1,6 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,23 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from enum import Enum

import numpy as np

from ..config import TeleoperatorConfig


class PhoneOS(Enum):
    ANDROID = "android"
    IOS = "ios"


@TeleoperatorConfig.register_subclass("phone")
@dataclass
class PhoneConfig(TeleoperatorConfig):
    phone_os: PhoneOS = PhoneOS.IOS
    camera_offset = np.array(
        [0.0, -0.02, 0.04]
    )  # iPhone 14 Pro camera is 2cm off center and 4cm above center

from .configuration_reachy2 import Reachy2RobotConfig
from .robot_reachy2 import (
    REACHY2_ANTENNAS_JOINTS,
    REACHY2_L_ARM_JOINTS,
    REACHY2_NECK_JOINTS,
    REACHY2_R_ARM_JOINTS,
    REACHY2_VEL,
    Reachy2Robot,
)
@@ -0,0 +1,107 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from lerobot.cameras import CameraConfig
from lerobot.cameras.configs import ColorMode
from lerobot.cameras.reachy2_camera import Reachy2CameraConfig

from ..config import RobotConfig


@RobotConfig.register_subclass("reachy2")
@dataclass
class Reachy2RobotConfig(RobotConfig):
    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors.
    max_relative_target: float | None = None

    # IP address of the Reachy 2 robot
    ip_address: str | None = "localhost"

    # If True, turn_off_smoothly() will be sent to the robot before disconnecting.
    disable_torque_on_disconnect: bool = False

    # Tag for external commands control
    # Set to True if you use an external commands system to control the robot,
    # such as the official teleoperation application: https://github.com/pollen-robotics/Reachy2Teleoperation
    # If True, robot.send_action() will not send commands to the robot.
    use_external_commands: bool = False

    # Robot parts
    # Set to False to not add the corresponding joints part to the robot list of joints.
    # By default, all parts are set to True.
    with_mobile_base: bool = True
    with_l_arm: bool = True
    with_r_arm: bool = True
    with_neck: bool = True
    with_antennas: bool = True

    # Robot cameras
    # Set to True if you want to use the corresponding cameras in the observations.
    # By default, only the teleop cameras are used.
    with_left_teleop_camera: bool = True
    with_right_teleop_camera: bool = True
    with_torso_camera: bool = False

    cameras: dict[str, CameraConfig] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Add cameras with same ip_address as the robot
        if self.with_left_teleop_camera:
            self.cameras["teleop_left"] = Reachy2CameraConfig(
                name="teleop",
                image_type="left",
                ip_address=self.ip_address,
                fps=15,
                width=640,
                height=480,
                color_mode=ColorMode.RGB,
            )
        if self.with_right_teleop_camera:
            self.cameras["teleop_right"] = Reachy2CameraConfig(
                name="teleop",
                image_type="right",
                ip_address=self.ip_address,
                fps=15,
                width=640,
                height=480,
                color_mode=ColorMode.RGB,
            )
        if self.with_torso_camera:
            self.cameras["torso_rgb"] = Reachy2CameraConfig(
                name="depth",
                image_type="rgb",
                ip_address=self.ip_address,
                fps=15,
                width=640,
                height=480,
                color_mode=ColorMode.RGB,
            )

        super().__post_init__()

        if not (
            self.with_mobile_base
            or self.with_l_arm
            or self.with_r_arm
            or self.with_neck
            or self.with_antennas
        ):
            raise ValueError(
                "No Reachy2Robot part used.\n"
                "At least one part of the robot must be set to True "
                "(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
            )
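Given the `__post_init__` above, construction auto-registers one camera per enabled flag, all pointed at the robot's IP; a hypothetical instantiation:

```python
# IP address is illustrative; camera keys follow from the flags set above.
cfg = Reachy2RobotConfig(ip_address="192.168.1.42", with_torso_camera=True)
print(sorted(cfg.cameras))  # ['teleop_left', 'teleop_right', 'torso_rgb']
```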
@@ -0,0 +1,230 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from typing import Any

import numpy as np
from reachy2_sdk import ReachySDK

from lerobot.cameras.utils import make_cameras_from_configs

from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .configuration_reachy2 import Reachy2RobotConfig

# {lerobot_keys: reachy2_sdk_keys}
REACHY2_NECK_JOINTS = {
    "neck_yaw.pos": "head.neck.yaw",
    "neck_pitch.pos": "head.neck.pitch",
    "neck_roll.pos": "head.neck.roll",
}

REACHY2_ANTENNAS_JOINTS = {
    "l_antenna.pos": "head.l_antenna",
    "r_antenna.pos": "head.r_antenna",
}

REACHY2_R_ARM_JOINTS = {
    "r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
    "r_shoulder_roll.pos": "r_arm.shoulder.roll",
    "r_elbow_yaw.pos": "r_arm.elbow.yaw",
    "r_elbow_pitch.pos": "r_arm.elbow.pitch",
    "r_wrist_roll.pos": "r_arm.wrist.roll",
    "r_wrist_pitch.pos": "r_arm.wrist.pitch",
    "r_wrist_yaw.pos": "r_arm.wrist.yaw",
    "r_gripper.pos": "r_arm.gripper",
}

REACHY2_L_ARM_JOINTS = {
    "l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
    "l_shoulder_roll.pos": "l_arm.shoulder.roll",
    "l_elbow_yaw.pos": "l_arm.elbow.yaw",
    "l_elbow_pitch.pos": "l_arm.elbow.pitch",
    "l_wrist_roll.pos": "l_arm.wrist.roll",
    "l_wrist_pitch.pos": "l_arm.wrist.pitch",
    "l_wrist_yaw.pos": "l_arm.wrist.yaw",
    "l_gripper.pos": "l_arm.gripper",
}

REACHY2_VEL = {
    "mobile_base.vx": "vx",
    "mobile_base.vy": "vy",
    "mobile_base.vtheta": "vtheta",
}


class Reachy2Robot(Robot):
    """
    [Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
    """

    config_class = Reachy2RobotConfig
    name = "reachy2"

    def __init__(self, config: Reachy2RobotConfig):
        super().__init__(config)

        self.config = config
        self.robot_type = self.config.type
        self.use_external_commands = self.config.use_external_commands

        self.reachy: None | ReachySDK = None
        self.cameras = make_cameras_from_configs(config.cameras)

        self.logs: dict[str, float] = {}

        self.joints_dict: dict[str, str] = self._generate_joints_dict()

    @property
    def observation_features(self) -> dict[str, Any]:
        return {**self.motors_features, **self.camera_features}

    @property
    def action_features(self) -> dict[str, type]:
        return self.motors_features

    @property
    def camera_features(self) -> dict[str, tuple[int | None, int | None, int]]:
        return {cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras}

    @property
    def motors_features(self) -> dict[str, type]:
        if self.config.with_mobile_base:
            return {
                **dict.fromkeys(
                    self.joints_dict.keys(),
                    float,
                ),
                **dict.fromkeys(
                    REACHY2_VEL.keys(),
                    float,
                ),
            }
        else:
            return dict.fromkeys(self.joints_dict.keys(), float)

    @property
    def is_connected(self) -> bool:
        return self.reachy.is_connected() if self.reachy is not None else False

    def connect(self, calibrate: bool = False) -> None:
        self.reachy = ReachySDK(self.config.ip_address)
        if not self.is_connected:
            raise ConnectionError()

        for cam in self.cameras.values():
            cam.connect()

        self.configure()

    def configure(self) -> None:
        if self.reachy is not None:
            self.reachy.turn_on()
            self.reachy.reset_default_limits()

    @property
    def is_calibrated(self) -> bool:
        return True

    def calibrate(self) -> None:
        pass

    def _generate_joints_dict(self) -> dict[str, str]:
        joints = {}
        if self.config.with_neck:
            joints.update(REACHY2_NECK_JOINTS)
        if self.config.with_l_arm:
            joints.update(REACHY2_L_ARM_JOINTS)
        if self.config.with_r_arm:
            joints.update(REACHY2_R_ARM_JOINTS)
        if self.config.with_antennas:
            joints.update(REACHY2_ANTENNAS_JOINTS)
        return joints

    def _get_state(self) -> dict[str, float]:
        if self.reachy is not None:
            pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()}
            if not self.config.with_mobile_base:
                return pos_dict
            vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
            return {**pos_dict, **vel_dict}
        else:
            return {}

    def get_observation(self) -> dict[str, np.ndarray]:
        obs_dict: dict[str, Any] = {}

        # Read Reachy 2 state
        before_read_t = time.perf_counter()
        obs_dict.update(self._get_state())
        self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t

        # Capture images from cameras
        for cam_key, cam in self.cameras.items():
            obs_dict[cam_key] = cam.async_read()

        return obs_dict

    def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
        if self.reachy is not None:
            if not self.is_connected:
                raise ConnectionError()

            before_write_t = time.perf_counter()

            vel = {}
            goal_pos = {}
            for key, val in action.items():
                if key not in self.joints_dict:
                    if key not in REACHY2_VEL:
                        raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
                    else:
                        vel[REACHY2_VEL[key]] = float(val)
                else:
                    if not self.use_external_commands and self.config.max_relative_target is not None:
                        goal_pos[key] = float(val)
                        goal_present_pos = {
                            key: (
                                goal_pos[key],
                                self.reachy.joints[self.joints_dict[key]].present_position,
                            )
                        }
                        safe_goal_pos = ensure_safe_goal_position(
                            goal_present_pos, float(self.config.max_relative_target)
                        )
                        val = safe_goal_pos[key]
                    self.reachy.joints[self.joints_dict[key]].goal_position = float(val)

            if self.config.with_mobile_base:
                self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"])

            # We don't send the goal positions if we control Reachy 2 externally
            if not self.use_external_commands:
                self.reachy.send_goal_positions()
                if self.config.with_mobile_base:
                    self.reachy.mobile_base.send_speed_command()

            self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
        return action

    def disconnect(self) -> None:
        if self.reachy is not None:
            for cam in self.cameras.values():
                cam.disconnect()
            if self.config.disable_torque_on_disconnect:
                self.reachy.turn_off_smoothly()
            self.reachy.disconnect()
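The `{lerobot_key: reachy2_sdk_key}` tables are what `_get_state` iterates over; a toy version of that lookup with a plain dict standing in for the SDK's joint container:

```python
REACHY2_NECK_JOINTS = {"neck_yaw.pos": "head.neck.yaw", "neck_pitch.pos": "head.neck.pitch"}
present_positions = {"head.neck.yaw": 12.5, "head.neck.pitch": -3.0}  # fake SDK readings

pos_dict = {k: present_positions[v] for k, v in REACHY2_NECK_JOINTS.items()}
print(pos_dict)  # {'neck_yaw.pos': 12.5, 'neck_pitch.pos': -3.0}
```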
@@ -14,5 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .config_so100_follower import SO100FollowerConfig
from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig
from .so100_follower import SO100Follower
from .so100_follower_end_effector import SO100FollowerEndEffector

@@ -30,12 +30,44 @@ class SO100FollowerConfig(RobotConfig):
    disable_torque_on_disconnect: bool = True

    # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
    # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
    # the number of motors in your follower arms.
    max_relative_target: int | None = None
    # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
    # names to the max_relative_target value for that motor.
    max_relative_target: float | dict[str, float] | None = None

    # cameras
    cameras: dict[str, CameraConfig] = field(default_factory=dict)

    # Set to `True` for backward compatibility with previous policies/dataset
    use_degrees: bool = False


@RobotConfig.register_subclass("so100_follower_end_effector")
@dataclass
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
    """Configuration for the SO100FollowerEndEffector robot."""

    # Path to URDF file for kinematics
    # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
    # https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
    urdf_path: str | None = None

    # End-effector frame name in the URDF
    target_frame_name: str = "gripper_frame_link"

    # Default bounds for the end-effector position (in meters)
    end_effector_bounds: dict[str, list[float]] = field(
        default_factory=lambda: {
            "min": [-1.0, -1.0, -1.0],  # min x, y, z
            "max": [1.0, 1.0, 1.0],  # max x, y, z
        }
    )

    max_gripper_pos: float = 50

    end_effector_step_sizes: dict[str, float] = field(
        default_factory=lambda: {
            "x": 0.02,
            "y": 0.02,
            "z": 0.02,
        }
    )
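The `end_effector_bounds` and `end_effector_step_sizes` defined here are consumed by the end-effector pipeline steps below; the bounds act as a simple component-wise clip on the commanded position:

```python
import numpy as np

end_effector_bounds = {"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}
commanded = np.array([1.3, 0.0, -0.2])  # example out-of-range command (meters)
clipped = np.clip(commanded, end_effector_bounds["min"], end_effector_bounds["max"])
print(clipped)  # [ 1.   0.  -0.2]
```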
@@ -1,460 +0,0 @@
# !/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

import numpy as np
from scipy.spatial.transform import Rotation

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
    ActionProcessorStep,
    ComplementaryDataProcessorStep,
    EnvTransition,
    ObservationProcessorStep,
    ProcessorStep,
    ProcessorStepRegistry,
    TransitionKey,
)
from lerobot.robots.robot import Robot


@ProcessorStepRegistry.register("ee_reference_and_delta")
@dataclass
class EEReferenceAndDelta(ActionProcessorStep):
    """
    Compute the desired end-effector pose from the target pose and the current pose.

    Input ACTION keys:
    {
        "action.ee.{x,y,z,wx,wy,wz}" : float
        "complementary_data.raw_joint_positions": dict,
    }

    Output ACTION keys:
    {
        "action.ee.{x,y,z,wx,wy,wz}" : float
    }
    """

    kinematics: RobotKinematics
    end_effector_step_sizes: dict
    motor_names: list[str]
    use_latched_reference: bool = (
        True  # If True, latch reference on enable; if False, always use current pose
    )

    reference_ee_pose: np.ndarray | None = field(default=None, init=False, repr=False)
    _prev_enabled: bool = field(default=False, init=False, repr=False)
    _command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False)

    def action(self, action):
        new_action = action.copy()
        comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)

        # Get joint positions from complementary data
        raw = comp.get("raw_joint_positions", None)
        if raw is None:
            raise ValueError(
                "raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
            )

        if "reference_joint_positions" in comp:
            q = comp["reference_joint_positions"]
        else:
            q = np.array([float(raw[n]) for n in self.motor_names], dtype=float)

        # Current pose from FK on measured joints
        t_curr = self.kinematics.forward_kinematics(q)

        enabled = bool(new_action.pop(f"{ACTION}.enabled", 0))
        tx = float(new_action.pop(f"{ACTION}.target_x", 0.0))
        ty = float(new_action.pop(f"{ACTION}.target_y", 0.0))
        tz = float(new_action.pop(f"{ACTION}.target_z", 0.0))
        wx = float(new_action.pop(f"{ACTION}.target_wx", 0.0))
        wy = float(new_action.pop(f"{ACTION}.target_wy", 0.0))
        wz = float(new_action.pop(f"{ACTION}.target_wz", 0.0))

        desired = None

        if enabled:
            ref = t_curr
            if self.use_latched_reference:
                # Latched reference mode: latch reference at the rising edge
                if not self._prev_enabled or self.reference_ee_pose is None:
                    self.reference_ee_pose = t_curr.copy()
                ref = self.reference_ee_pose if self.reference_ee_pose is not None else t_curr

            delta_p = np.array(
                [
                    tx * self.end_effector_step_sizes["x"],
                    ty * self.end_effector_step_sizes["y"],
                    tz * self.end_effector_step_sizes["z"],
                ],
                dtype=float,
            )
            r_abs = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
            desired = np.eye(4, dtype=float)
            desired[:3, :3] = ref[:3, :3] @ r_abs
            desired[:3, 3] = ref[:3, 3] + delta_p

            self._command_when_disabled = desired.copy()
        else:
            # While disabled, keep sending the same command to avoid drift.
            if self._command_when_disabled is None:
                # If we've never had an enabled command yet, freeze current FK pose once.
                self._command_when_disabled = t_curr.copy()
            desired = self._command_when_disabled.copy()

        # Write action fields
        pos = desired[:3, 3]
        tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
        new_action[f"{ACTION}.ee.x"] = float(pos[0])
        new_action[f"{ACTION}.ee.y"] = float(pos[1])
        new_action[f"{ACTION}.ee.z"] = float(pos[2])
        new_action[f"{ACTION}.ee.wx"] = float(tw[0])
        new_action[f"{ACTION}.ee.wy"] = float(tw[1])
        new_action[f"{ACTION}.ee.wz"] = float(tw[2])

        self._prev_enabled = enabled
        return new_action

    def reset(self):
        self._prev_enabled = False
        self.reference_ee_pose = None
        self._command_when_disabled = None

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        features.pop(f"{ACTION}.enabled", None)
        features.pop(f"{ACTION}.target_x", None)
        features.pop(f"{ACTION}.target_y", None)
        features.pop(f"{ACTION}.target_z", None)
        features.pop(f"{ACTION}.target_wx", None)
        features.pop(f"{ACTION}.target_wy", None)
        features.pop(f"{ACTION}.target_wz", None)

        features[f"{ACTION}.ee.x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        features[f"{ACTION}.ee.y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        features[f"{ACTION}.ee.z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        features[f"{ACTION}.ee.wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        features[f"{ACTION}.ee.wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        features[f"{ACTION}.ee.wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
        return features
|
||||
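
# Usage sketch (illustrative, not part of the original pipeline). The step sizes and motor
# names below are hypothetical placeholders; in practice they come from the robot/teleop config:
#
#     step = EEReferenceAndDelta(
#         kinematics=kinematics,  # an already-constructed RobotKinematics instance
#         end_effector_step_sizes={"x": 0.002, "y": 0.002, "z": 0.002},  # meters per unit command
#         motor_names=["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll"],
#         use_latched_reference=True,
#     )
#
# With use_latched_reference=True, deltas are applied against the pose captured when
# "action.enabled" first goes high; with False, each delta is applied to the live FK pose.
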
@ProcessorStepRegistry.register("ee_bounds_and_safety")
|
||||
@dataclass
|
||||
class EEBoundsAndSafety(ActionProcessorStep):
|
||||
"""
|
||||
Clip the end-effector pose to the bounds and check for jumps.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
"""
|
||||
|
||||
end_effector_bounds: dict
|
||||
max_ee_step_m: float = 0.05
|
||||
max_ee_twist_step_rad: float = 0.20
|
||||
_last_pos: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def action(self, act: dict) -> dict:
|
||||
x = act.get(f"{ACTION}.ee.x", None)
|
||||
y = act.get(f"{ACTION}.ee.y", None)
|
||||
z = act.get(f"{ACTION}.ee.z", None)
|
||||
wx = act.get(f"{ACTION}.ee.wx", None)
|
||||
wy = act.get(f"{ACTION}.ee.wy", None)
|
||||
wz = act.get(f"{ACTION}.ee.wz", None)
|
||||
|
||||
if None in (x, y, z, wx, wy, wz):
|
||||
raise ValueError(
|
||||
"Missing required end-effector pose components: x, y, z, wx, wy, wz must all be present in action"
|
||||
)
|
||||
|
||||
pos = np.array([x, y, z], dtype=float)
|
||||
twist = np.array([wx, wy, wz], dtype=float)
|
||||
|
||||
# Clip position
|
||||
pos = np.clip(pos, self.end_effector_bounds["min"], self.end_effector_bounds["max"])
|
||||
|
||||
# Check for jumps in position
|
||||
if self._last_pos is not None:
|
||||
dpos = pos - self._last_pos
|
||||
n = float(np.linalg.norm(dpos))
|
||||
if n > self.max_ee_step_m and n > 0:
|
||||
pos = self._last_pos + dpos * (self.max_ee_step_m / n)
|
||||
raise ValueError(f"EE jump {n:.3f}m > {self.max_ee_step_m}m")
|
||||
|
||||
self._last_pos = pos
|
||||
self._last_twist = twist
|
||||
|
||||
act[f"{ACTION}.ee.x"] = float(pos[0])
|
||||
act[f"{ACTION}.ee.y"] = float(pos[1])
|
||||
act[f"{ACTION}.ee.z"] = float(pos[2])
|
||||
act[f"{ACTION}.ee.wx"] = float(twist[0])
|
||||
act[f"{ACTION}.ee.wy"] = float(twist[1])
|
||||
act[f"{ACTION}.ee.wz"] = float(twist[2])
|
||||
return act
|
||||
|
||||
def reset(self):
|
||||
self._last_pos = None
|
||||
self._last_twist = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# check if features as f"{ACTION}.ee.{x,y,z,wx,wy,wz}"
|
||||
|
||||
return features
|
||||
|
||||
|
||||
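
# Worked example of the jump check above (illustrative, exact binary floats chosen on purpose):
# with max_ee_step_m = 0.05, a commanded move from z = 0.25 to z = 0.375 is a 0.125 m step,
# so the step raises a ValueError instead of forwarding the command:
#
#     >>> float(np.linalg.norm(np.array([0.0, 0.0, 0.375]) - np.array([0.0, 0.0, 0.25])))
#     0.125
#
# Note that rotation jumps are not yet validated against max_ee_twist_step_rad; only the
# position step is checked, while _last_twist is stored for a future check.
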
@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
|
||||
@dataclass
|
||||
class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
"""
|
||||
Compute the desired joint positions from the desired end-effector pose.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
"complementary_data.raw_joint_positions": dict,
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.joint_name_1.pos": float,
|
||||
"action.joint_name_2.pos": float,
|
||||
...
|
||||
"action.joint_name_n.pos": float,
|
||||
}
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
q_curr: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
initial_guess_current_joints: bool = True
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
new_transition = transition.copy()
|
||||
act = new_transition.get(TransitionKey.ACTION) or {}
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
x = act.get(f"{ACTION}.ee.x", None)
|
||||
y = act.get(f"{ACTION}.ee.y", None)
|
||||
z = act.get(f"{ACTION}.ee.z", None)
|
||||
wx = act.get(f"{ACTION}.ee.wx", None)
|
||||
wy = act.get(f"{ACTION}.ee.wy", None)
|
||||
wz = act.get(f"{ACTION}.ee.wz", None)
|
||||
|
||||
if None in (x, y, z, wx, wy, wz):
|
||||
return new_transition
|
||||
|
||||
# Get joint positions from complimentary data
|
||||
raw = comp.get("raw_joint_positions", None)
|
||||
if raw is None:
|
||||
raise ValueError(
|
||||
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
|
||||
)
|
||||
|
||||
if self.initial_guess_current_joints: # Use current joints as initial guess
|
||||
self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
|
||||
else: # Use previous ik solution as initial guess
|
||||
if self.q_curr is None:
|
||||
self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
|
||||
|
||||
# Build desired 4x4 transform from pos + rotvec (twist)
|
||||
t_des = np.eye(4, dtype=float)
|
||||
t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
|
||||
t_des[:3, 3] = [x, y, z]
|
||||
|
||||
# Compute inverse kinematics
|
||||
q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des)
|
||||
self.q_curr = q_target
|
||||
|
||||
new_act = dict(act)
|
||||
for i, name in enumerate(self.motor_names):
|
||||
if name == "gripper":
|
||||
# TODO(pepijn): Investigate if this is correct
|
||||
# Do we want an observation key in the action field?
|
||||
new_act[f"{ACTION}.gripper.pos"] = float(raw["gripper"])
|
||||
else:
|
||||
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
if not self.initial_guess_current_joints:
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features[f"{ACTION}.gripper.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for name in self.motor_names:
|
||||
features[f"{ACTION}.{name}.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
|
||||
return features
|
||||
|
||||
def reset(self):
|
||||
self.q_curr = None
|
||||
|
||||
|
||||
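
# Note on the initial guess (illustrative): iterative IK solvers converge to the solution
# nearest their seed. Seeding with the measured joints (initial_guess_current_joints=True)
# tracks the physical robot, while seeding with the previous IK solution keeps consecutive
# targets on the same IK branch even if the robot lags behind, at the cost of possible drift
# from the measured state. In the latter mode the solution is also exported as
# complementary_data["reference_joint_positions"], so EEReferenceAndDelta runs FK on the
# commanded rather than the measured joints:
#
#     step = InverseKinematicsEEToJoints(
#         kinematics=kinematics, motor_names=motor_names, initial_guess_current_joints=False
#     )
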
@ProcessorStepRegistry.register("gripper_velocity_to_joint")
|
||||
@dataclass
|
||||
class GripperVelocityToJoint(ProcessorStep):
|
||||
"""
|
||||
Convert the gripper velocity to a joint velocity.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.gripper": float,
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.gripper.pos": float,
|
||||
}
|
||||
"""
|
||||
|
||||
motor_names: list[str]
|
||||
speed_factor: float = 20.0
|
||||
clip_min: float = 0.0
|
||||
clip_max: float = 100.0
|
||||
discrete_gripper: bool = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
new_transition = transition.copy()
|
||||
obs = new_transition.get(TransitionKey.OBSERVATION) or {}
|
||||
act = new_transition.get(TransitionKey.ACTION) or {}
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
if f"{ACTION}.gripper" not in act:
|
||||
raise ValueError(f"Required action key '{ACTION}.gripper' not found in transition")
|
||||
|
||||
if "gripper" not in self.motor_names:
|
||||
raise ValueError(
|
||||
f"Required motor name 'gripper' not found in self.motor_names={self.motor_names}"
|
||||
)
|
||||
|
||||
if self.discrete_gripper:
|
||||
# Discrete gripper actions are in [0, 1, 2]
|
||||
# 0: open, 1: close, 2: stay
|
||||
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
|
||||
gripper_action = act.get(f"{ACTION}.gripper", 1.0)
|
||||
gripper_action = gripper_action - 1.0
|
||||
gripper_action *= self.clip_max
|
||||
act[f"{ACTION}.gripper"] = gripper_action
|
||||
|
||||
# Get current gripper position from complementary data
|
||||
raw = comp.get("raw_joint_positions") or {}
|
||||
curr_pos = float(raw.get("gripper"))
|
||||
|
||||
# Compute desired gripper velocity
|
||||
u = float(act.get(f"{ACTION}.gripper", 0.0))
|
||||
delta = u * float(self.speed_factor)
|
||||
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
|
||||
|
||||
new_act = dict(act)
|
||||
new_act[f"{ACTION}.gripper.pos"] = gripper_pos
|
||||
new_act.pop(f"{ACTION}.gripper", None)
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
|
||||
obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
|
||||
new_transition[TransitionKey.OBSERVATION] = obs
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop(f"{ACTION}.gripper", None)
|
||||
features[f"{ACTION}.gripper.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{OBS_STATE}.gripper.pos"] = PolicyFeature(type=FeatureType.STATE, shape=(1,))
|
||||
|
||||
return features
|
||||
|
||||
|
||||
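
# Worked example of the velocity integration above (illustrative): with speed_factor=20,
# clip range [0, 100], current gripper position 42.0 and a continuous command u = +1.0:
#
#     >>> float(np.clip(42.0 + 1.0 * 20.0, 0.0, 100.0))
#     62.0
#
# A discrete command of 2 first becomes (2 - 1) * clip_max = 100.0; the resulting delta of
# 100.0 * 20.0 then saturates at clip_max, i.e. one step drives the gripper fully in that
# direction.
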
@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee")
|
||||
@dataclass
|
||||
class ForwardKinematicsJointsToEE(ObservationProcessorStep):
|
||||
"""
|
||||
Compute the end-effector pose from the joint positions.
|
||||
|
||||
Input OBSERVATION keys:
|
||||
{
|
||||
"observation.state.{joint_name_1,joint_name_2,...,joint_name_n}.pos": float,
|
||||
}
|
||||
|
||||
Output OBSERVATION keys:
|
||||
{
|
||||
"observation.state.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
|
||||
def observation(self, obs: dict) -> dict:
|
||||
if not all(f"{OBS_STATE}.{n}.pos" in obs for n in self.motor_names):
|
||||
raise ValueError(f"Missing required joint positions for motors: {self.motor_names}")
|
||||
|
||||
q = np.array([obs[f"{OBS_STATE}.{n}.pos"] for n in self.motor_names], dtype=float)
|
||||
t = self.kinematics.forward_kinematics(q)
|
||||
pos = t[:3, 3]
|
||||
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
|
||||
|
||||
obs[f"{OBS_STATE}.ee.x"] = float(pos[0])
|
||||
obs[f"{OBS_STATE}.ee.y"] = float(pos[1])
|
||||
obs[f"{OBS_STATE}.ee.z"] = float(pos[2])
|
||||
obs[f"{OBS_STATE}.ee.wx"] = float(tw[0])
|
||||
obs[f"{OBS_STATE}.ee.wy"] = float(tw[1])
|
||||
obs[f"{OBS_STATE}.ee.wz"] = float(tw[2])
|
||||
return obs
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We specify the dataset features of this step that we want to be stored in the dataset
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz"]:
|
||||
features[f"{OBS_STATE}.ee.{k}"] = PolicyFeature(type=FeatureType.STATE, shape=(1,))
|
||||
return features
|
||||
|
||||
|
||||
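
# Shape note (illustrative): forward_kinematics returns a 4x4 homogeneous transform, so the
# translation is its last column and the orientation is encoded as a rotation vector
# (axis * angle, in radians) via scipy:
#
#     >>> Rotation.from_matrix(np.eye(4)[:3, :3]).as_rotvec()
#     array([0., 0., 0.])
#
# Storing the pose as six scalar keys keeps it consistent with the per-joint
# "observation.state.<name>.pos" convention used elsewhere in the dataset.
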
@ProcessorStepRegistry.register("add_robot_observation")
|
||||
@dataclass
|
||||
class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
Read the robot's current observation and insert it into the transition as complementary data.
|
||||
|
||||
- Joint positions are added under complementary_data["raw_joint_positions"] as a dict:
|
||||
{ "<motor_name>": <float position>, ... }
|
||||
"""
|
||||
|
||||
robot: Robot
|
||||
|
||||
def complementary_data(self, comp: dict | None) -> dict:
|
||||
new_comp = dict(comp)
|
||||
obs = (
|
||||
self.robot.get_observation()
|
||||
) # todo(steven): why not self.trtansition.get(TransitionKey.OBSERVATION)?
|
||||
|
||||
new_comp["raw_joint_positions"] = {
|
||||
k.removesuffix(".pos"): float(v)
|
||||
for k, v in obs.items()
|
||||
if isinstance(k, str) and k.endswith(".pos")
|
||||
}
|
||||
return new_comp
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
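
# End-to-end composition sketch (illustrative, not part of this module). A typical EE teleop
# pipeline chains these steps in order; `robot`, `kinematics`, `motor_names`, `mins`, and
# `maxs` are hypothetical placeholders for values taken from the robot config:
#
#     steps = [
#         AddRobotObservationAsComplimentaryData(robot=robot),   # 1. snapshot raw joints
#         EEReferenceAndDelta(                                   # 2. deltas -> absolute EE pose
#             kinematics=kinematics,
#             end_effector_step_sizes={"x": 0.002, "y": 0.002, "z": 0.002},
#             motor_names=motor_names,
#         ),
#         EEBoundsAndSafety(end_effector_bounds={"min": mins, "max": maxs}),  # 3. clip + jump check
#         InverseKinematicsEEToJoints(kinematics=kinematics, motor_names=motor_names),  # 4. pose -> joints
#         GripperVelocityToJoint(motor_names=motor_names),       # 5. gripper velocity -> position
#     ]
#
# Ordering matters: every downstream step assumes "raw_joint_positions" is already present
# in the complementary data, which is why the observation step runs first.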