mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-12 15:19:43 +00:00
Compare commits
33 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bbcf66bd82 | |||
| c378a325f0 | |||
| 90684a9690 | |||
| f59eb54f5c | |||
| 62e9849ffd | |||
| e3b572992e | |||
| 5b647e3bcb | |||
| ddfff054bc | |||
| 49918efbc1 | |||
| c5b5955c5a | |||
| ec40ccde0d | |||
| d2782cf66b | |||
| 9627765ce2 | |||
| 43d878a102 | |||
| ddba994d73 | |||
| a87d4c9a74 | |||
| 170c09e7f6 | |||
| 853cc70194 | |||
| ec63225dc1 | |||
| af1760f175 | |||
| 163df97c0c | |||
| cdd2bf1c4e | |||
| 1cba47da20 | |||
| 7359e18eb6 | |||
| 13010647bc | |||
| acbc14f60a | |||
| 2b59850f15 | |||
| 42e4b3d09e | |||
| 98bcda2d8b | |||
| a4178f385b | |||
| bd09b2153f | |||
| 1033680a57 | |||
| 7cf04a5ec3 |
@@ -31,11 +31,11 @@ env:
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
WARN_ISSUE_MESSAGE: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity (1 year). It will be closed if no further activity occurs.
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
Thank you for your contributions.
|
||||
WARN_PR_MESSAGE: >
|
||||
This PR has been automatically marked as stale because it has not had
|
||||
recent activity (1 year). It will be closed if no further activity occurs.
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
Thank you for your contributions.
|
||||
|
||||
jobs:
|
||||
|
||||
@@ -35,12 +35,13 @@ import torch
|
||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||
from tqdm import tqdm
|
||||
|
||||
from benchmarks.video.benchmark import TimeBenchmark
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.video_utils import (
|
||||
decode_video_frames_torchvision,
|
||||
encode_video_frames,
|
||||
)
|
||||
from lerobot.utils.benchmark import TimeBenchmark
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
|
||||
BASE_ENCODING = OrderedDict(
|
||||
[
|
||||
@@ -117,7 +118,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# We only save images from the first camera
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
||||
|
||||
for i, item in enumerate(
|
||||
|
||||
@@ -31,7 +31,7 @@ Then, spin up a policy server (in one terminal, or in a separate machine) specif
|
||||
You can spin up a policy server running:
|
||||
|
||||
```shell
|
||||
python src/lerobot/scripts/server/policy_server.py \
|
||||
python src/lerobot/async_inference/policy_server.py \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080 \
|
||||
```
|
||||
@@ -39,7 +39,7 @@ python src/lerobot/scripts/server/policy_server.py \
|
||||
This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters are specified during the first handshake with the client. Spin up a client with:
|
||||
|
||||
```shell
|
||||
python src/lerobot/scripts/server/robot_client.py \
|
||||
python src/lerobot/async_inference/robot_client.py \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
@@ -122,8 +122,8 @@ python -m lerobot.scripts.server.policy_server \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.scripts.server.configs import PolicyServerConfig
|
||||
from lerobot.scripts.server.policy_server import serve
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host="localhost",
|
||||
@@ -148,7 +148,7 @@ The `RobotClient` streams observations to the `PolicyServer`, and receives actio
|
||||
<hfoptions id="start_robot_client">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python src/lerobot/scripts/server/robot_client.py \
|
||||
python src/lerobot/async_inference/robot_client.py \
|
||||
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
|
||||
--robot.type=so100_follower \ # ROBOT: your robot type
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
|
||||
@@ -171,9 +171,9 @@ python src/lerobot/scripts/server/robot_client.py \
|
||||
import threading
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.scripts.server.configs import RobotClientConfig
|
||||
from lerobot.scripts.server.robot_client import RobotClient
|
||||
from lerobot.scripts.server.helpers import visualize_action_queue_size
|
||||
from lerobot.async_inference.configs import RobotClientConfig
|
||||
from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.async_inference.helpers import visualize_action_queue_size
|
||||
|
||||
# 1. Create the robot instance
|
||||
"""Check out the cameras available in your setup by running `python lerobot/find_cameras.py`"""
|
||||
|
||||
+8
-11
@@ -95,7 +95,6 @@ class HILSerlProcessorConfig:
|
||||
class ObservationConfig:
|
||||
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
|
||||
add_current_to_observation: bool = False # Add motor currents to state
|
||||
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
|
||||
display_cameras: bool = False # Display camera feeds during execution
|
||||
|
||||
class ImagePreprocessingConfig:
|
||||
@@ -105,7 +104,6 @@ class ImagePreprocessingConfig:
|
||||
class GripperConfig:
|
||||
use_gripper: bool = True # Enable gripper control
|
||||
gripper_penalty: float = 0.0 # Penalty for inappropriate gripper usage
|
||||
gripper_penalty_in_reward: bool = False # Include gripper penalty in reward
|
||||
|
||||
class ResetConfig:
|
||||
fixed_reset_joint_positions: Any | None = None # Joint positions for reset
|
||||
@@ -288,7 +286,6 @@ You can enable multiple observation processing features simultaneously:
|
||||
"observation": {
|
||||
"add_joint_velocity_to_observation": true,
|
||||
"add_current_to_observation": true,
|
||||
"add_ee_pose_to_observation": false,
|
||||
"display_cameras": false
|
||||
}
|
||||
}
|
||||
@@ -304,19 +301,19 @@ Before collecting demonstrations, you need to determine the appropriate operatio
|
||||
|
||||
This helps simplify the problem of learning on the real robot in two ways: 1) by limiting the robot's operational space to a specific region that solves the task and avoids unnecessary or unsafe exploration, and 2) by allowing training in end-effector space rather than joint space. Empirically, learning in joint space for reinforcement learning in manipulation is often a harder problem - some tasks are nearly impossible to learn in joint space but become learnable when the action space is transformed to end-effector coordinates.
|
||||
|
||||
**Using find_joint_limits.py**
|
||||
**Using lerobot-find-joint-limits**
|
||||
|
||||
This script helps you find the safe operational bounds for your robot's end-effector. Given that you have a follower and leader arm, you can use the script to find the bounds for the follower arm that will be applied during training.
|
||||
Bounding the action space will reduce the redundant exploration of the agent and guarantees safety.
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.find_joint_limits \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=blue
|
||||
lerobot-find-joint-limits \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=blue
|
||||
```
|
||||
|
||||
**Workflow**
|
||||
|
||||
@@ -200,7 +200,7 @@ from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderCo
|
||||
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.record import record_loop
|
||||
|
||||
NUM_EPISODES = 5
|
||||
@@ -237,7 +237,7 @@ dataset = LeRobotDataset.create(
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="recording")
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
@@ -517,7 +517,7 @@ from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerCon
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.policies.factory import make_processor
|
||||
|
||||
@@ -557,7 +557,7 @@ dataset = LeRobotDataset.create(
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="recording")
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
@@ -277,7 +277,7 @@ leader.disconnect()
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./il_robots)
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
|
||||
|
||||
@@ -323,7 +323,7 @@ To replay an episode run the API example below, make sure to change `remote_ip`,
|
||||
python examples/lekiwi/replay.py
|
||||
```
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by the training part of this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by the training part of this tutorial: [Getting started with real-world robots](./il_robots)
|
||||
|
||||
## Evaluate your policy
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ To Install LIBERO, after following LeRobot official instructions, just do:
|
||||
Evaluate a policy on one LIBERO suite:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/eval.py \
|
||||
lerobot-eval \
|
||||
--policy.path="your-policy-id" \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
@@ -52,7 +52,7 @@ python src/lerobot/scripts/eval.py \
|
||||
Benchmark a policy across multiple suites at once:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/eval.py \
|
||||
lerobot-eval \
|
||||
--policy.path="your-policy-id" \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object,libero_spatial \
|
||||
@@ -103,10 +103,11 @@ For reference, here is the **original dataset** published by Physical Intelligen
|
||||
### Example training command
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/train.py \
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--policy.repo_id=${HF_USER}/libero-test \
|
||||
--dataset.repo_id=jadechoghari/smol-libero3 \
|
||||
--policy.load_vlm_weights=true \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--output_dir=./outputs/ \
|
||||
|
||||
@@ -136,13 +136,12 @@ Additionally you can customize mapping or safety limits by editing the processor
|
||||
),
|
||||
```
|
||||
|
||||
- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` and `max_ee_twist_step_rad` are the step limits for the EE pose and can be modified to change the safety limits.
|
||||
- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` are the step limits for the EE pose and can be modified to change the safety limits.
|
||||
|
||||
```examples/phone_to_so100/teleoperate.py
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ phone_to_robot_ee_pose_processor = RobotProcessorPipeline[RobotAction, RobotActi
|
||||
kinematics=kinematics_solver, end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, motor_names=list(robot.bus.motors.keys()),
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.20, max_ee_twist_step_rad=0.50,
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.20,
|
||||
),
|
||||
GripperVelocityToJoint(),
|
||||
],
|
||||
|
||||
@@ -29,7 +29,7 @@ SmolVLA is Hugging Face’s lightweight foundation model for robotics. Designed
|
||||
## Collect a dataset
|
||||
|
||||
SmolVLA is a base model, so fine-tuning on your own data is required for optimal performance in your setup.
|
||||
We recommend recording ~50 episodes of your task as a starting point. Follow our guide to get started: [Recording a Dataset](https://huggingface.co/docs/lerobot/getting_started_real_world_robot#record-a-dataset)
|
||||
We recommend recording ~50 episodes of your task as a starting point. Follow our guide to get started: [Recording a Dataset](./il_robots)
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -93,7 +93,7 @@ lerobot-train --help
|
||||
|
||||
## Evaluate the finetuned model and run it in real-time
|
||||
|
||||
Similarly for when recording an episode, it is recommended that you are logged in to the HuggingFace Hub. You can follow the corresponding steps: [Record a dataset](./getting_started_real_world_robot#record-a-dataset).
|
||||
Similarly for when recording an episode, it is recommended that you are logged in to the HuggingFace Hub. You can follow the corresponding steps: [Record a dataset](./il_robots).
|
||||
Once you are logged in, you can run inference in your setup by doing:
|
||||
|
||||
```bash
|
||||
|
||||
@@ -634,7 +634,7 @@ leader.disconnect()
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./il_robots)
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
|
||||
|
||||
@@ -430,7 +430,7 @@ leader.disconnect()
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./il_robots)
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
|
||||
|
||||
@@ -44,6 +44,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
@@ -78,16 +79,16 @@ def replay(cfg: ReplayConfig):
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
actions = dataset.hf_dataset.select_columns(ACTION)
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action_array = actions[idx]["action"]
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features["action"]["names"]):
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
key = f"{name.removeprefix('main_')}.pos"
|
||||
action[key] = action_array[i].item()
|
||||
|
||||
|
||||
@@ -19,11 +19,12 @@ from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
NUM_EPISODES = 2
|
||||
FPS = 30
|
||||
@@ -41,8 +42,8 @@ robot = LeKiwiClient(robot_config)
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
@@ -73,7 +74,7 @@ teleop_action_processor, robot_action_processor, robot_observation_processor = m
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="lekiwi_evaluate")
|
||||
init_rerun(session_name="lekiwi_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
@@ -17,14 +17,15 @@
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
NUM_EPISODES = 2
|
||||
FPS = 30
|
||||
@@ -47,8 +48,8 @@ keyboard = KeyboardTeleop(keyboard_config)
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
@@ -69,7 +70,7 @@ keyboard.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="lekiwi_record")
|
||||
init_rerun(session_name="lekiwi_record")
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
@@ -19,6 +19,7 @@ import time
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -34,7 +35,7 @@ robot = LeKiwiClient(robot_config)
|
||||
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns("action")
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
@@ -49,7 +50,7 @@ for idx in range(len(episode_frames)):
|
||||
|
||||
# Get recorded action from dataset
|
||||
action = {
|
||||
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Send action to robot
|
||||
|
||||
@@ -20,7 +20,7 @@ from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
@@ -41,7 +41,7 @@ leader_arm.connect()
|
||||
keyboard.connect()
|
||||
|
||||
# Init rerun viewer
|
||||
_init_rerun(session_name="lekiwi_teleop")
|
||||
init_rerun(session_name="lekiwi_teleop")
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
@@ -34,16 +34,16 @@ from lerobot.processor.converters import (
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
@@ -137,7 +137,7 @@ robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="phone_so100_evaluate")
|
||||
init_rerun(session_name="phone_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
@@ -26,7 +26,6 @@ from lerobot.processor.converters import (
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
@@ -36,12 +35,13 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
NUM_EPISODES = 2
|
||||
FPS = 30
|
||||
@@ -84,7 +84,6 @@ phone_to_robot_ee_pose_processor = RobotProcessorPipeline[tuple[RobotAction, Rob
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
GripperVelocityToJoint(speed_factor=20.0),
|
||||
],
|
||||
@@ -143,7 +142,7 @@ phone.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="phone_so100_record")
|
||||
init_rerun(session_name="phone_so100_record")
|
||||
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
@@ -28,6 +28,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -66,7 +67,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotOb
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns("action")
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
@@ -81,7 +82,7 @@ for idx in range(len(episode_frames)):
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get robot observation
|
||||
|
||||
@@ -33,7 +33,7 @@ from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
@@ -67,7 +67,6 @@ phone_to_robot_joints_processor = RobotProcessorPipeline[tuple[RobotAction, Robo
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
speed_factor=20.0,
|
||||
@@ -87,7 +86,7 @@ robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
# Init rerun viewer
|
||||
_init_rerun(session_name="phone_so100_teleop")
|
||||
init_rerun(session_name="phone_so100_teleop")
|
||||
|
||||
if not robot.is_connected or not teleop_device.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
@@ -34,16 +34,16 @@ from lerobot.processor.converters import (
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
@@ -138,7 +138,7 @@ robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="so100_so100_evaluate")
|
||||
init_rerun(session_name="so100_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
@@ -27,7 +27,6 @@ from lerobot.processor.converters import (
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
@@ -35,11 +34,12 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
NUM_EPISODES = 2
|
||||
FPS = 30
|
||||
@@ -101,7 +101,6 @@ ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservati
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
@@ -143,7 +142,7 @@ follower.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="recording_phone")
|
||||
init_rerun(session_name="recording_phone")
|
||||
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
@@ -29,6 +29,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -67,7 +68,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotOb
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns("action")
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
@@ -82,7 +83,7 @@ for idx in range(len(episode_frames)):
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get robot observation
|
||||
|
||||
@@ -33,7 +33,7 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
@@ -78,7 +78,6 @@ ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservati
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
@@ -95,7 +94,7 @@ follower.connect()
|
||||
leader.connect()
|
||||
|
||||
# Init rerun viewer
|
||||
_init_rerun(session_name="so100_so100_EE_teleop")
|
||||
init_rerun(session_name="so100_so100_EE_teleop")
|
||||
|
||||
print("Starting teleop loop...")
|
||||
while True:
|
||||
|
||||
@@ -20,13 +20,13 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.constants import ACTION
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
+86
-10
@@ -162,17 +162,18 @@ all = [
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
lerobot-calibrate="lerobot.calibrate:main"
|
||||
lerobot-find-cameras="lerobot.find_cameras:main"
|
||||
lerobot-find-port="lerobot.find_port:main"
|
||||
lerobot-record="lerobot.record:main"
|
||||
lerobot-replay="lerobot.replay:main"
|
||||
lerobot-setup-motors="lerobot.setup_motors:main"
|
||||
lerobot-teleoperate="lerobot.teleoperate:main"
|
||||
lerobot-eval="lerobot.scripts.eval:main"
|
||||
lerobot-train="lerobot.scripts.train:main"
|
||||
lerobot-calibrate="lerobot.scripts.lerobot_calibrate:main"
|
||||
lerobot-find-cameras="lerobot.scripts.lerobot_find_cameras:main"
|
||||
lerobot-find-port="lerobot.scripts.lerobot_find_port:main"
|
||||
lerobot-record="lerobot.scripts.lerobot_record:main"
|
||||
lerobot-replay="lerobot.scripts.lerobot_replay:main"
|
||||
lerobot-setup-motors="lerobot.scripts.lerobot_setup_motors:main"
|
||||
lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main"
|
||||
lerobot-eval="lerobot.scripts.lerobot_eval:main"
|
||||
lerobot-train="lerobot.scripts.lerobot_train:main"
|
||||
lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
|
||||
lerobot-info="lerobot.scripts.lerobot_info:main"
|
||||
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
@@ -200,7 +201,7 @@ exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
|
||||
# N: pep8-naming
|
||||
# TODO: Uncomment rules when ready to use
|
||||
select = [
|
||||
"E", "W", "F", "I", "B", "C4", "T20", "N" # "SIM", "A", "S", "D", "RUF", "UP"
|
||||
"E", "W", "F", "I", "B", "C4", "T20", "N", "UP", "SIM" #, "A", "S", "D", "RUF"
|
||||
]
|
||||
ignore = [
|
||||
"E501", # Line too long
|
||||
@@ -266,8 +267,83 @@ default.extend-ignore-identifiers-re = [
|
||||
# color = true
|
||||
# paths = ["src/lerobot"]
|
||||
|
||||
# TODO: Enable mypy gradually module by module across multiple PRs
|
||||
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
|
||||
|
||||
# [tool.mypy]
|
||||
# python_version = "3.10"
|
||||
# warn_return_any = true
|
||||
# warn_unused_configs = true
|
||||
# ignore_missing_imports = false
|
||||
# strict = true
|
||||
# disallow_untyped_defs = true
|
||||
# disallow_incomplete_defs = true
|
||||
# check_untyped_defs = true
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.utils.*"
|
||||
# # include = "src/lerobot/utils/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.configs.*"
|
||||
# # include = "src/lerobot/configs/**/*.py"
|
||||
|
||||
# # Data processing modules
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.processor.*"
|
||||
# # include = "src/lerobot/processor/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.datasets.*"
|
||||
# # include = "src/lerobot/datasets/**/*.py"
|
||||
|
||||
# # Core machine learning modules
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.optim.*"
|
||||
# # include = "src/lerobot/optim/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.model.*"
|
||||
# # include = "src/lerobot/model/**/*.py"
|
||||
|
||||
# # Hardware interfaces
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.cameras.*"
|
||||
# # include = "src/lerobot/cameras/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.motors.*"
|
||||
# # include = "src/lerobot/motors/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.robots.*"
|
||||
# # include = "src/lerobot/robots/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.teleoperators.*"
|
||||
# # include = "src/lerobot/teleoperators/**/*.py"
|
||||
|
||||
# # Complex modules (enable these last)
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.policies.*"
|
||||
# # include = "src/lerobot/policies/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.rl.*"
|
||||
# # include = "src/lerobot/rl/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.envs.*"
|
||||
# # include = "src/lerobot/envs/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.async_inference.*"
|
||||
# # include = "src/lerobot/async_inference/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.transport.*"
|
||||
# # include = "src/lerobot/transport/**/*.py"
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.scripts.*"
|
||||
# # include = "src/lerobot/scripts/**/*.py"
|
||||
|
||||
@@ -18,7 +18,8 @@ from dataclasses import dataclass, field
|
||||
import torch
|
||||
|
||||
from lerobot.robots.config import RobotConfig
|
||||
from lerobot.scripts.server.constants import (
|
||||
|
||||
from .constants import (
|
||||
DEFAULT_FPS,
|
||||
DEFAULT_INFERENCE_LATENCY,
|
||||
DEFAULT_OBS_QUEUE_TIMEOUT,
|
||||
@@ -22,16 +22,15 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
|
||||
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
||||
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
Action = torch.Tensor
|
||||
ActionChunk = torch.Tensor
|
||||
|
||||
# observation as received from the robot
|
||||
RawObservation = dict[str, torch.Tensor]
|
||||
@@ -46,7 +45,7 @@ Observation = dict[str, torch.Tensor]
|
||||
def visualize_action_queue_size(action_queue_size: list[int]) -> None:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
_, ax = plt.subplots()
|
||||
ax.set_title("Action Queue Size Over Time")
|
||||
ax.set_xlabel("Environment steps")
|
||||
ax.set_ylabel("Action Queue Size")
|
||||
@@ -66,7 +65,7 @@ def validate_robot_cameras_for_policy(
|
||||
|
||||
|
||||
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
|
||||
return hw_to_dataset_features(robot.observation_features, "observation", use_video=False)
|
||||
return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False)
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
|
||||
@@ -141,7 +140,7 @@ def make_lerobot_observation(
|
||||
lerobot_features: dict[str, dict],
|
||||
) -> LeRobotObservation:
|
||||
"""Make a lerobot observation from a raw observation."""
|
||||
return build_dataset_frame(lerobot_features, robot_obs, prefix="observation")
|
||||
return build_dataset_frame(lerobot_features, robot_obs, prefix=OBS_STR)
|
||||
|
||||
|
||||
def prepare_raw_observation(
|
||||
+10
-9
@@ -15,7 +15,7 @@
|
||||
"""
|
||||
Example:
|
||||
```shell
|
||||
python src/lerobot/scripts/server/policy_server.py \
|
||||
python src/lerobot/async_inference/policy_server.py \
|
||||
--host=127.0.0.1 \
|
||||
--port=8080 \
|
||||
--fps=30 \
|
||||
@@ -38,9 +38,15 @@ import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
from lerobot.scripts.server.configs import PolicyServerConfig
|
||||
from lerobot.scripts.server.constants import SUPPORTED_POLICIES
|
||||
from lerobot.scripts.server.helpers import (
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks
|
||||
|
||||
from .configs import PolicyServerConfig
|
||||
from .constants import SUPPORTED_POLICIES
|
||||
from .helpers import (
|
||||
FPSTracker,
|
||||
Observation,
|
||||
RemotePolicyConfig,
|
||||
@@ -50,11 +56,6 @@ from lerobot.scripts.server.helpers import (
|
||||
observations_similar,
|
||||
raw_observation_to_observation,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks
|
||||
|
||||
|
||||
class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
+10
-9
@@ -15,7 +15,7 @@
|
||||
"""
|
||||
Example command:
|
||||
```shell
|
||||
python src/lerobot/scripts/server/robot_client.py \
|
||||
python src/lerobot/async_inference/robot_client.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
@@ -57,9 +57,15 @@ from lerobot.robots import ( # noqa: F401
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.scripts.server.configs import RobotClientConfig
|
||||
from lerobot.scripts.server.constants import SUPPORTED_ROBOTS
|
||||
from lerobot.scripts.server.helpers import (
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
|
||||
|
||||
from .configs import RobotClientConfig
|
||||
from .constants import SUPPORTED_ROBOTS
|
||||
from .helpers import (
|
||||
Action,
|
||||
FPSTracker,
|
||||
Observation,
|
||||
@@ -72,11 +78,6 @@ from lerobot.scripts.server.helpers import (
|
||||
validate_robot_cameras_for_policy,
|
||||
visualize_action_queue_size,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
|
||||
|
||||
|
||||
class RobotClient:
|
||||
@@ -31,7 +31,7 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..utils import get_cv2_backend, get_cv2_rotation
|
||||
|
||||
@@ -31,7 +31,7 @@ import numpy as np
|
||||
from reachy2_sdk.media.camera import CameraView
|
||||
from reachy2_sdk.media.camera_manager import CameraManager
|
||||
|
||||
from lerobot.errors import DeviceNotConnectedError
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from .configuration_reachy2_camera import ColorMode, Reachy2CameraConfig
|
||||
|
||||
@@ -29,7 +29,7 @@ try:
|
||||
except Exception as e:
|
||||
logging.info(f"Could not import realsense: {e}")
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..configs import ColorMode
|
||||
|
||||
@@ -15,14 +15,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import TypeAlias
|
||||
|
||||
from .camera import Camera
|
||||
from .configs import CameraConfig, Cv2Rotation
|
||||
|
||||
IndexOrPath: TypeAlias = int | Path
|
||||
|
||||
|
||||
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]:
|
||||
cameras = {}
|
||||
|
||||
@@ -16,9 +16,6 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot import (
|
||||
policies, # noqa: F401
|
||||
)
|
||||
from lerobot.datasets.transforms import ImageTransformsConfig
|
||||
from lerobot.datasets.video_utils import get_safe_default_codec
|
||||
|
||||
|
||||
@@ -27,9 +27,9 @@ from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
class FeatureType(str, Enum):
|
||||
@@ -38,10 +37,6 @@ class NormalizationMode(str, Enum):
|
||||
IDENTITY = "IDENTITY"
|
||||
|
||||
|
||||
class DictLike(Protocol):
|
||||
def __getitem__(self, key: Any) -> Any: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyFeature:
|
||||
type: FeatureType
|
||||
|
||||
@@ -93,14 +93,13 @@ def update_data_df(df, src_meta, dst_meta):
|
||||
pd.DataFrame: Updated DataFrame with adjusted indices.
|
||||
"""
|
||||
|
||||
def _update(row):
|
||||
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
|
||||
row["index"] = row["index"] + dst_meta.info["total_frames"]
|
||||
task = src_meta.tasks.iloc[row["task_index"]].name
|
||||
row["task_index"] = dst_meta.tasks.loc[task].task_index.item()
|
||||
return row
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
df["index"] = df["index"] + dst_meta.info["total_frames"]
|
||||
|
||||
return df.apply(_update, axis=1)
|
||||
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
|
||||
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def update_meta_data(
|
||||
@@ -126,27 +125,21 @@ def update_meta_data(
|
||||
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
|
||||
"""
|
||||
|
||||
def _update(row):
|
||||
row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
||||
row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"]
|
||||
row["data/chunk_index"] = row["data/chunk_index"] + data_idx["chunk"]
|
||||
row["data/file_index"] = row["data/file_index"] + data_idx["file"]
|
||||
for key, video_idx in videos_idx.items():
|
||||
row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + video_idx["chunk"]
|
||||
row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + video_idx["file"]
|
||||
row[f"videos/{key}/from_timestamp"] = (
|
||||
row[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
||||
)
|
||||
row[f"videos/{key}/to_timestamp"] = (
|
||||
row[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
|
||||
)
|
||||
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
||||
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
|
||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||
for key, video_idx in videos_idx.items():
|
||||
df[f"videos/{key}/chunk_index"] = df[f"videos/{key}/chunk_index"] + video_idx["chunk"]
|
||||
df[f"videos/{key}/file_index"] = df[f"videos/{key}/file_index"] + video_idx["file"]
|
||||
df[f"videos/{key}/from_timestamp"] = df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
||||
df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
|
||||
|
||||
row["dataset_from_index"] = row["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||
row["dataset_to_index"] = row["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
|
||||
return row
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
|
||||
return df.apply(_update, axis=1)
|
||||
return df
|
||||
|
||||
|
||||
def aggregate_datasets(
|
||||
|
||||
@@ -27,6 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
|
||||
)
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms
|
||||
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
|
||||
|
||||
IMAGENET_STATS = {
|
||||
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
||||
@@ -54,11 +55,11 @@ def resolve_delta_timestamps(
|
||||
"""
|
||||
delta_timestamps = {}
|
||||
for key in ds_meta.features:
|
||||
if key == "next.reward" and cfg.reward_delta_indices is not None:
|
||||
if key == REWARD and cfg.reward_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
|
||||
if key == "action" and cfg.action_delta_indices is not None:
|
||||
if key == ACTION and cfg.action_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
||||
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
|
||||
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
||||
|
||||
if len(delta_timestamps) == 0:
|
||||
|
||||
@@ -31,7 +31,6 @@ import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.datasets.utils import (
|
||||
@@ -79,6 +78,7 @@ from lerobot.datasets.video_utils import (
|
||||
get_video_duration_in_s,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
@@ -848,11 +848,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return item
|
||||
|
||||
def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
|
||||
for key, val in padding.items():
|
||||
item[key] = torch.BoolTensor(val)
|
||||
return item
|
||||
|
||||
def __len__(self):
|
||||
return self.num_frames
|
||||
|
||||
@@ -1032,7 +1027,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Reset episode buffer and clean up temporary images (if not already deleted during video encoding)
|
||||
self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0)
|
||||
|
||||
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None):
|
||||
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
|
||||
"""
|
||||
Batch save videos for multiple episodes.
|
||||
|
||||
@@ -1158,7 +1153,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
}
|
||||
return metadata
|
||||
|
||||
def _save_episode_video(self, video_key: str, episode_index: int):
|
||||
def _save_episode_video(self, video_key: str, episode_index: int) -> dict:
|
||||
# Encode episode frames into a temporary video
|
||||
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
|
||||
ep_size_in_mb = get_video_size_in_mb(ep_path)
|
||||
@@ -1263,7 +1258,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.wait_until_done()
|
||||
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> dict:
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""
|
||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
@@ -1396,11 +1391,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
|
||||
|
||||
@property
|
||||
def repo_index_to_id(self):
|
||||
"""Return the inverse mapping if repo_id_to_index."""
|
||||
return {v: k for k, v in self.repo_id_to_index}
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection.
|
||||
@@ -1431,7 +1421,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
"""Keys to access image and video stream from cameras."""
|
||||
keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, (datasets.Image, VideoFrame)):
|
||||
if isinstance(feats, (datasets.Image | VideoFrame)):
|
||||
keys.append(key)
|
||||
return keys
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@ from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType
|
||||
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor import DataProcessorPipeline
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
|
||||
|
||||
def create_initial_features(
|
||||
@@ -92,8 +92,8 @@ def aggregate_pipeline_dataset_features(
|
||||
|
||||
# Intermediate storage for categorized and filtered features.
|
||||
processed_features: dict[str, dict[str, Any]] = {
|
||||
"action": {},
|
||||
"observation": {},
|
||||
ACTION: {},
|
||||
OBS_STR: {},
|
||||
}
|
||||
images_token = OBS_IMAGES.split(".")[-1]
|
||||
|
||||
@@ -125,17 +125,15 @@ def aggregate_pipeline_dataset_features(
|
||||
# 3. Add the feature to the appropriate group with a clean name.
|
||||
name = strip_prefix(key, PREFIXES_TO_STRIP)
|
||||
if is_action:
|
||||
processed_features["action"][name] = value
|
||||
processed_features[ACTION][name] = value
|
||||
else:
|
||||
processed_features["observation"][name] = value
|
||||
processed_features[OBS_STR][name] = value
|
||||
|
||||
# Convert the processed features into the final dataset format.
|
||||
dataset_features = {}
|
||||
if processed_features["action"]:
|
||||
dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos))
|
||||
if processed_features["observation"]:
|
||||
dataset_features.update(
|
||||
hw_to_dataset_features(processed_features["observation"], "observation", use_videos)
|
||||
)
|
||||
if processed_features[ACTION]:
|
||||
dataset_features.update(hw_to_dataset_features(processed_features[ACTION], ACTION, use_videos))
|
||||
if processed_features[OBS_STR]:
|
||||
dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos))
|
||||
|
||||
return dataset_features
|
||||
|
||||
@@ -13,67 +13,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.video_utils import encode_video_frames
|
||||
|
||||
|
||||
def concatenate_episodes(ep_dicts):
|
||||
data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
if torch.is_tensor(ep_dicts[0][key][0]):
|
||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for x in ep_dict[key]:
|
||||
data_dict[key].append(x)
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
|
||||
|
||||
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def save_image(img_array, i, out_dir):
|
||||
img = PIL.Image.fromarray(img_array)
|
||||
img.save(str(out_dir / f"frame_{i:06d}.png"), quality=100)
|
||||
|
||||
num_images = len(imgs_array)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||
|
||||
|
||||
def get_default_encoding() -> dict:
|
||||
"""Returns the default ffmpeg encoding parameters used by `encode_video_frames`."""
|
||||
signature = inspect.signature(encode_video_frames)
|
||||
return {
|
||||
k: v.default
|
||||
for k, v in signature.parameters.items()
|
||||
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
|
||||
}
|
||||
|
||||
|
||||
def check_repo_id(repo_id: str) -> None:
|
||||
if len(repo_id.split("/")) != 2:
|
||||
raise ValueError(
|
||||
f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
|
||||
(e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
|
||||
)
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
|
||||
|
||||
@@ -21,7 +21,6 @@ import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from lerobot.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
Backtrackable,
|
||||
@@ -38,6 +37,7 @@ from lerobot.datasets.video_utils import (
|
||||
VideoDecoderCache,
|
||||
decode_video_frames_torchcodec,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||
|
||||
|
||||
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
@@ -298,9 +298,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
return padding_mask
|
||||
|
||||
def make_frame(
|
||||
self, dataset_iterator: Backtrackable, previous_dataset_iterator: Backtrackable | None = None
|
||||
) -> Generator:
|
||||
def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
|
||||
"""Makes a frame starting from a dataset iterator"""
|
||||
item = next(dataset_iterator)
|
||||
item = item_to_torch(item)
|
||||
|
||||
@@ -120,7 +120,7 @@ class SharpnessJitter(Transform):
|
||||
self.sharpness = self._check_input(sharpness)
|
||||
|
||||
def _check_input(self, sharpness):
|
||||
if isinstance(sharpness, (int, float)):
|
||||
if isinstance(sharpness, (int | float)):
|
||||
if sharpness < 0:
|
||||
raise ValueError("If sharpness is a single number, it must be non negative.")
|
||||
sharpness = [1.0 - sharpness, 1.0 + sharpness]
|
||||
|
||||
@@ -21,7 +21,7 @@ from collections import deque
|
||||
from collections.abc import Iterable, Iterator
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any, Deque, Generic, TypeVar
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
@@ -43,6 +43,7 @@ from lerobot.datasets.backward_compatibility import (
|
||||
BackwardCompatibilityError,
|
||||
ForwardCompatibilityError,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
@@ -66,18 +67,6 @@ DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{fram
|
||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
|
||||
LEGACY_DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
# Metadata will go there
|
||||
---
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
## {}
|
||||
|
||||
"""
|
||||
|
||||
DEFAULT_FEATURES = {
|
||||
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
@@ -218,13 +207,13 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
"""
|
||||
serialized_dict = {}
|
||||
for key, value in flatten_dict(stats).items():
|
||||
if isinstance(value, (torch.Tensor, np.ndarray)):
|
||||
if isinstance(value, (torch.Tensor | np.ndarray)):
|
||||
serialized_dict[key] = value.tolist()
|
||||
elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
|
||||
elif isinstance(value, list) and isinstance(value[0], (int | float | list)):
|
||||
serialized_dict[key] = value
|
||||
elif isinstance(value, np.generic):
|
||||
serialized_dict[key] = value.item()
|
||||
elif isinstance(value, (int, float)):
|
||||
elif isinstance(value, (int | float)):
|
||||
serialized_dict[key] = value
|
||||
else:
|
||||
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
|
||||
@@ -382,12 +371,6 @@ def load_episodes(local_dir: Path) -> datasets.Dataset:
|
||||
return episodes
|
||||
|
||||
|
||||
def backward_compatible_episodes_stats(
|
||||
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
|
||||
) -> dict[int, dict[str, dict[str, np.ndarray]]]:
|
||||
return dict.fromkeys(episodes, stats)
|
||||
|
||||
|
||||
def load_image_as_numpy(
|
||||
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
|
||||
) -> np.ndarray:
|
||||
@@ -645,14 +628,14 @@ def hw_to_dataset_features(
|
||||
}
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
|
||||
if joint_fts and prefix == "action":
|
||||
if joint_fts and prefix == ACTION:
|
||||
features[prefix] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(joint_fts),),
|
||||
"names": list(joint_fts),
|
||||
}
|
||||
|
||||
if joint_fts and prefix == "observation":
|
||||
if joint_fts and prefix == OBS_STR:
|
||||
features[f"{prefix}.state"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(joint_fts),),
|
||||
@@ -728,11 +711,11 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif key == "observation.environment_state":
|
||||
elif key == OBS_ENV_STATE:
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith("observation"):
|
||||
elif key.startswith(OBS_STR):
|
||||
type = FeatureType.STATE
|
||||
elif key.startswith("action"):
|
||||
elif key.startswith(ACTION):
|
||||
type = FeatureType.ACTION
|
||||
else:
|
||||
continue
|
||||
@@ -1196,7 +1179,7 @@ def item_to_torch(item: dict) -> dict:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray, list)) and key not in ["task"]:
|
||||
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
@@ -1270,8 +1253,8 @@ class Backtrackable(Generic[T]):
|
||||
raise ValueError("lookahead must be > 0")
|
||||
|
||||
self._source: Iterator[T] = iter(iterable)
|
||||
self._back_buf: Deque[T] = deque(maxlen=history)
|
||||
self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
|
||||
self._back_buf: deque[T] = deque(maxlen=history)
|
||||
self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
|
||||
self._cursor: int = 0
|
||||
self._history = history
|
||||
self._lookahead = lookahead
|
||||
@@ -1345,12 +1328,6 @@ class Backtrackable(Generic[T]):
|
||||
# When cursor<0, slice so the order remains chronological
|
||||
return list(self._back_buf)[: self._cursor or None]
|
||||
|
||||
def lookahead_buffer(self) -> list[T]:
|
||||
"""
|
||||
Return a copy of the current lookahead buffer.
|
||||
"""
|
||||
return list(self._ahead_buf)
|
||||
|
||||
def can_peek_back(self, steps: int = 1) -> bool:
|
||||
"""
|
||||
Check if we can go back `steps` items without raising an IndexError.
|
||||
@@ -1376,31 +1353,6 @@ class Backtrackable(Generic[T]):
|
||||
except StopIteration:
|
||||
return False
|
||||
|
||||
def reset_cursor(self) -> None:
|
||||
"""
|
||||
Reset cursor to the most recent position (equivalent to calling next()
|
||||
until you're back to the latest item).
|
||||
"""
|
||||
self._cursor = 0
|
||||
|
||||
def clear_ahead_buffer(self) -> None:
|
||||
"""
|
||||
Clear the ahead buffer, discarding any pre-fetched items.
|
||||
"""
|
||||
self._ahead_buf.clear()
|
||||
|
||||
def switch_source_iterable(self, new_source: Iterable[T]) -> None:
|
||||
"""
|
||||
Switch the source of the backtrackable to a new iterable, keeping the history.
|
||||
|
||||
This is useful when iterating over a sequence of datasets. The history from the
|
||||
previous source is kept, but the lookahead buffer is cleared. The cursor is reset
|
||||
to the present.
|
||||
"""
|
||||
self._source = iter(new_source)
|
||||
self.clear_ahead_buffer()
|
||||
self.reset_cursor()
|
||||
|
||||
|
||||
def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.Dataset:
|
||||
"""
|
||||
|
||||
@@ -34,6 +34,7 @@ python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -46,7 +47,6 @@ from datasets import Dataset, Features, Image
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from requests import HTTPError
|
||||
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
@@ -71,6 +71,8 @@ from lerobot.datasets.utils import (
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
V21 = "v2.1"
|
||||
|
||||
@@ -144,6 +146,7 @@ def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
|
||||
|
||||
def convert_tasks(root, new_root):
|
||||
logging.info(f"Converting tasks from {root} to {new_root}")
|
||||
tasks, _ = legacy_load_tasks(root)
|
||||
task_indices = tasks.keys()
|
||||
task_strings = tasks.values()
|
||||
@@ -185,7 +188,10 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
num_frames = 0
|
||||
paths_to_cat = []
|
||||
episodes_metadata = []
|
||||
for ep_path in ep_paths:
|
||||
|
||||
logging.info(f"Converting data files from {len(ep_paths)} episodes")
|
||||
|
||||
for ep_path in tqdm.tqdm(ep_paths, desc="convert data files"):
|
||||
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
|
||||
ep_num_frames = get_parquet_num_frames(ep_path)
|
||||
ep_metadata = {
|
||||
@@ -209,7 +215,6 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
|
||||
# Reset for the next file
|
||||
size_in_mb = ep_size_in_mb
|
||||
num_frames = ep_num_frames
|
||||
paths_to_cat = [ep_path]
|
||||
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
@@ -236,6 +241,8 @@ def get_image_keys(root):
|
||||
|
||||
|
||||
def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int):
|
||||
logging.info(f"Converting videos from {root} to {new_root}")
|
||||
|
||||
video_keys = get_video_keys(root)
|
||||
if len(video_keys) == 0:
|
||||
return None
|
||||
@@ -254,7 +261,7 @@ def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int):
|
||||
episods_metadata = []
|
||||
num_cameras = len(video_keys)
|
||||
num_episodes = num_eps_per_cam[0]
|
||||
for ep_idx in range(num_episodes):
|
||||
for ep_idx in tqdm.tqdm(range(num_episodes), desc="convert videos"):
|
||||
# Sanity check
|
||||
ep_ids = [eps_metadata_per_cam[cam_idx][ep_idx]["episode_index"] for cam_idx in range(num_cameras)]
|
||||
ep_ids += [ep_idx]
|
||||
@@ -281,6 +288,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
||||
duration_in_s = 0.0
|
||||
paths_to_cat = []
|
||||
episodes_metadata = []
|
||||
|
||||
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
|
||||
ep_size_in_mb = get_video_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
@@ -374,6 +382,8 @@ def generate_episode_metadata_dict(
|
||||
|
||||
|
||||
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
|
||||
logging.info(f"Converting episodes metadata from {root} to {new_root}")
|
||||
|
||||
episodes_legacy_metadata = legacy_load_episodes(root)
|
||||
episodes_stats = legacy_load_episodes_stats(root)
|
||||
|
||||
@@ -405,6 +415,7 @@ def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
|
||||
info["data_path"] = DEFAULT_DATA_PATH
|
||||
info["video_path"] = DEFAULT_VIDEO_PATH
|
||||
info["fps"] = int(info["fps"])
|
||||
logging.info(f"Converting info from {root} to {new_root}")
|
||||
for key in info["features"]:
|
||||
if info["features"][key]["dtype"] == "video":
|
||||
# already has fps in video_info
|
||||
@@ -469,6 +480,7 @@ def convert_dataset(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
|
||||
@@ -428,7 +428,7 @@ def concatenate_video_files(
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
|
||||
tmp_concatenate_file.write("ffconcat version 1.0\n")
|
||||
for input_path in input_video_paths:
|
||||
tmp_concatenate_file.write(f"file '{str(input_path)}'\n")
|
||||
tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
|
||||
tmp_concatenate_file.flush()
|
||||
tmp_concatenate_path = tmp_concatenate_file.name
|
||||
|
||||
@@ -437,7 +437,9 @@ def concatenate_video_files(
|
||||
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
|
||||
) # safe = 0 allows absolute paths as well as relative paths
|
||||
|
||||
tmp_output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
||||
tmp_output_video_path = tmp_named_file.name
|
||||
|
||||
output_container = av.open(
|
||||
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
|
||||
) # faststart is to move the metadata to the beginning of the file to speed up loading
|
||||
@@ -585,19 +587,6 @@ def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_image_pixel_channels(image: Image):
|
||||
if image.mode == "L":
|
||||
return 1 # Grayscale
|
||||
elif image.mode == "LA":
|
||||
return 2 # Grayscale + Alpha
|
||||
elif image.mode == "RGB":
|
||||
return 3 # RGB
|
||||
elif image.mode == "RGBA":
|
||||
return 4 # RGBA
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_video_duration_in_s(video_path: Path | str) -> float:
|
||||
"""
|
||||
Get the duration of a video file in seconds using PyAV.
|
||||
|
||||
@@ -19,9 +19,9 @@ from typing import Any
|
||||
import draccus
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.robots import RobotConfig
|
||||
from lerobot.teleoperators.config import TeleoperatorConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -53,12 +53,12 @@ class AlohaEnv(EnvConfig):
|
||||
render_mode: str = "rgb_array"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"top": f"{OBS_IMAGE}.top",
|
||||
"pixels/top": f"{OBS_IMAGES}.top",
|
||||
@@ -93,13 +93,13 @@ class PushtEnv(EnvConfig):
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"environment_state": OBS_ENV_STATE,
|
||||
"pixels": OBS_IMAGE,
|
||||
@@ -135,13 +135,13 @@ class XarmEnv(EnvConfig):
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
@@ -193,7 +193,6 @@ class ObservationConfig:
|
||||
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
display_cameras: bool = False
|
||||
|
||||
|
||||
@@ -203,7 +202,6 @@ class GripperConfig:
|
||||
|
||||
use_gripper: bool = True
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_penalty_in_reward: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -259,12 +257,12 @@ class LiberoEnv(EnvConfig):
|
||||
camera_name_mapping: dict[str, str] | None = (None,)
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"pixels/agentview_image": f"{OBS_IMAGES}.image",
|
||||
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
|
||||
|
||||
@@ -35,7 +35,7 @@ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||
"""Normalize camera_name into a non-empty list of strings."""
|
||||
if isinstance(camera_name, str):
|
||||
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
||||
elif isinstance(camera_name, (list, tuple)):
|
||||
elif isinstance(camera_name, (list | tuple)):
|
||||
cams = [str(c).strip() for c in camera_name if str(c).strip()]
|
||||
else:
|
||||
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
|
||||
|
||||
@@ -26,6 +26,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.utils import get_channel_first_image_shape
|
||||
|
||||
|
||||
@@ -41,9 +42,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
return_observations = {}
|
||||
if "pixels" in observations:
|
||||
if isinstance(observations["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in observations["pixels"].items()}
|
||||
else:
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
imgs = {OBS_IMAGE: observations["pixels"]}
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
@@ -72,13 +73,13 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
if env_state.dim() == 1:
|
||||
env_state = env_state.unsqueeze(0)
|
||||
|
||||
return_observations["observation.environment_state"] = env_state
|
||||
return_observations[OBS_ENV_STATE] = env_state
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
return_observations["observation.state"] = agent_pos
|
||||
return_observations[OBS_STATE] = agent_pos
|
||||
|
||||
return return_observations
|
||||
|
||||
@@ -182,10 +183,10 @@ def _(env: Mapping) -> None:
|
||||
|
||||
@close_envs.register
|
||||
def _(envs: Sequence) -> None:
|
||||
if isinstance(envs, (str, bytes)):
|
||||
if isinstance(envs, (str | bytes)):
|
||||
return
|
||||
for v in envs:
|
||||
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str, bytes)):
|
||||
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str | bytes)):
|
||||
close_envs(v)
|
||||
elif hasattr(v, "close"):
|
||||
_close_single_env(v)
|
||||
|
||||
@@ -22,7 +22,7 @@ import logging
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
|
||||
from lerobot.utils.encoding_utils import decode_twos_complement, encode_twos_complement
|
||||
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
||||
from .tables import (
|
||||
|
||||
@@ -17,7 +17,7 @@ from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pprint import pformat
|
||||
|
||||
from lerobot.utils.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
|
||||
from .tables import (
|
||||
|
||||
@@ -32,7 +32,7 @@ import serial
|
||||
from deepdiff import DeepDiff
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
NameOrID: TypeAlias = str | int
|
||||
@@ -99,12 +99,6 @@ class Motor:
|
||||
norm_mode: MotorNormMode
|
||||
|
||||
|
||||
class JointOutOfRangeError(Exception):
|
||||
def __init__(self, message="Joint is out of range"):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class PortHandler(Protocol):
|
||||
def __init__(self, port_name):
|
||||
self.is_open: bool
|
||||
@@ -348,7 +342,7 @@ class MotorsBus(abc.ABC):
|
||||
raise TypeError(motors)
|
||||
|
||||
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
|
||||
if isinstance(values, (int, float)):
|
||||
if isinstance(values, (int | float)):
|
||||
return dict.fromkeys(self.ids, values)
|
||||
elif isinstance(values, dict):
|
||||
return {self.motors[motor].id: val for motor, val in values.items()}
|
||||
@@ -675,7 +669,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str, int)):
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
@@ -703,7 +697,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str, int)):
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
@@ -739,7 +733,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str, int)):
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
|
||||
@@ -22,11 +22,11 @@ import draccus
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.constants import (
|
||||
from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json
|
||||
from lerobot.utils.constants import (
|
||||
OPTIMIZER_PARAM_GROUPS,
|
||||
OPTIMIZER_STATE,
|
||||
)
|
||||
from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json
|
||||
from lerobot.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
|
||||
@@ -22,8 +22,8 @@ import draccus
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
from lerobot.constants import SCHEDULER_STATE
|
||||
from lerobot.datasets.utils import write_json
|
||||
from lerobot.utils.constants import SCHEDULER_STATE
|
||||
from lerobot.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
|
||||
@@ -33,9 +33,9 @@ from torch import Tensor, nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.constants import ACTION, OBS_IMAGES
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
class ACTPolicy(PreTrainedPolicy):
|
||||
@@ -394,25 +394,22 @@ class ACT(nn.Module):
|
||||
latent dimension.
|
||||
"""
|
||||
if self.config.use_vae and self.training:
|
||||
assert "action" in batch, (
|
||||
assert ACTION in batch, (
|
||||
"actions must be provided when using the variational objective in training mode."
|
||||
)
|
||||
|
||||
if "observation.images" in batch:
|
||||
batch_size = batch["observation.images"][0].shape[0]
|
||||
else:
|
||||
batch_size = batch["observation.environment_state"].shape[0]
|
||||
batch_size = batch[OBS_IMAGES][0].shape[0] if OBS_IMAGES in batch else batch[OBS_ENV_STATE].shape[0]
|
||||
|
||||
# Prepare the latent for input to the transformer encoder.
|
||||
if self.config.use_vae and "action" in batch and self.training:
|
||||
if self.config.use_vae and ACTION in batch and self.training:
|
||||
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
|
||||
cls_embed = einops.repeat(
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
) # (B, 1, D)
|
||||
if self.config.robot_state_feature:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[OBS_STATE])
|
||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch[ACTION]) # (B, S, D)
|
||||
|
||||
if self.config.robot_state_feature:
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
@@ -430,7 +427,7 @@ class ACT(nn.Module):
|
||||
cls_joint_is_pad = torch.full(
|
||||
(batch_size, 2 if self.config.robot_state_feature else 1),
|
||||
False,
|
||||
device=batch["observation.state"].device,
|
||||
device=batch[OBS_STATE].device,
|
||||
)
|
||||
key_padding_mask = torch.cat(
|
||||
[cls_joint_is_pad, batch["action_is_pad"]], axis=1
|
||||
@@ -454,7 +451,7 @@ class ACT(nn.Module):
|
||||
mu = log_sigma_x2 = None
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
|
||||
batch["observation.state"].device
|
||||
batch[OBS_STATE].device
|
||||
)
|
||||
|
||||
# Prepare transformer encoder inputs.
|
||||
@@ -462,18 +459,16 @@ class ACT(nn.Module):
|
||||
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
|
||||
# Robot state token.
|
||||
if self.config.robot_state_feature:
|
||||
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
|
||||
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch[OBS_STATE]))
|
||||
# Environment state token.
|
||||
if self.config.env_state_feature:
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_env_state_input_proj(batch["observation.environment_state"])
|
||||
)
|
||||
encoder_in_tokens.append(self.encoder_env_state_input_proj(batch[OBS_ENV_STATE]))
|
||||
|
||||
if self.config.image_features:
|
||||
# For a list of images, the H and W may vary but H*W is constant.
|
||||
# NOTE: If modifying this section, verify on MPS devices that
|
||||
# gradients remain stable (no explosions or NaNs).
|
||||
for img in batch["observation.images"]:
|
||||
for img in batch[OBS_IMAGES]:
|
||||
cam_features = self.backbone(img)["feature_map"]
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features)
|
||||
|
||||
@@ -17,7 +17,6 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -29,6 +28,7 @@ from lerobot.processor import (
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_act_pre_post_processors(
|
||||
|
||||
@@ -33,7 +33,6 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import (
|
||||
@@ -42,6 +41,7 @@ from lerobot.policies.utils import (
|
||||
get_output_shape,
|
||||
populate_queues,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
class DiffusionPolicy(PreTrainedPolicy):
|
||||
@@ -81,13 +81,13 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
def reset(self):
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.n_action_steps),
|
||||
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
if self.config.image_features:
|
||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
|
||||
if self.config.env_state_feature:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -234,7 +234,7 @@ class DiffusionModel(nn.Module):
|
||||
if self.config.image_features:
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
|
||||
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
|
||||
images_per_camera = einops.rearrange(batch[OBS_IMAGES], "b s n ... -> n (b s) ...")
|
||||
img_features_list = torch.cat(
|
||||
[
|
||||
encoder(images)
|
||||
@@ -249,7 +249,7 @@ class DiffusionModel(nn.Module):
|
||||
else:
|
||||
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
@@ -275,7 +275,7 @@ class DiffusionModel(nn.Module):
|
||||
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
||||
}
|
||||
"""
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
# Encode image features and concatenate them all together along with the state vector.
|
||||
@@ -306,10 +306,10 @@ class DiffusionModel(nn.Module):
|
||||
}
|
||||
"""
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
|
||||
assert "observation.images" in batch or "observation.environment_state" in batch
|
||||
n_obs_steps = batch["observation.state"].shape[1]
|
||||
horizon = batch["action"].shape[1]
|
||||
assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
|
||||
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
|
||||
n_obs_steps = batch[OBS_STATE].shape[1]
|
||||
horizon = batch[ACTION].shape[1]
|
||||
assert horizon == self.config.horizon
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
@@ -317,7 +317,7 @@ class DiffusionModel(nn.Module):
|
||||
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
|
||||
|
||||
# Forward diffusion.
|
||||
trajectory = batch["action"]
|
||||
trajectory = batch[ACTION]
|
||||
# Sample noise to add to the trajectory.
|
||||
eps = torch.randn(trajectory.shape, device=trajectory.device)
|
||||
# Sample a random noising timestep for each item in the batch.
|
||||
@@ -338,7 +338,7 @@ class DiffusionModel(nn.Module):
|
||||
if self.config.prediction_type == "epsilon":
|
||||
target = eps
|
||||
elif self.config.prediction_type == "sample":
|
||||
target = batch["action"]
|
||||
target = batch[ACTION]
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -30,6 +29,7 @@ from lerobot.processor import (
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_diffusion_pre_post_processors(
|
||||
|
||||
@@ -24,7 +24,6 @@ from typing_extensions import Unpack
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
@@ -46,6 +45,7 @@ from lerobot.processor.converters import (
|
||||
transition_to_batch,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0")
|
||||
@@ -113,7 +114,7 @@ class PI0Config(PreTrainedConfig):
|
||||
# raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
key = f"{OBS_IMAGES}.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 480, 640),
|
||||
|
||||
@@ -21,6 +21,7 @@ import torch
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
def display(tensor: torch.Tensor):
|
||||
@@ -60,26 +61,26 @@ def main():
|
||||
|
||||
# Override stats
|
||||
dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
|
||||
dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
|
||||
dataset_meta.stats[OBS_STATE]["mean"] = torch.tensor(
|
||||
norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
|
||||
)
|
||||
dataset_meta.stats["observation.state"]["std"] = torch.tensor(
|
||||
dataset_meta.stats[OBS_STATE]["std"] = torch.tensor(
|
||||
norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
|
||||
)
|
||||
|
||||
# Create LeRobot batch from Jax
|
||||
batch = {}
|
||||
for cam_key, uint_chw_array in example["images"].items():
|
||||
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
|
||||
batch["observation.state"] = torch.from_numpy(example["state"])
|
||||
batch["action"] = torch.from_numpy(outputs["actions"])
|
||||
batch[f"{OBS_IMAGES}.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
|
||||
batch[OBS_STATE] = torch.from_numpy(example["state"])
|
||||
batch[ACTION] = torch.from_numpy(outputs["actions"])
|
||||
batch["task"] = example["prompt"]
|
||||
|
||||
if model_name == "pi0_aloha_towel":
|
||||
del batch["observation.images.cam_low"]
|
||||
del batch[f"{OBS_IMAGES}.cam_low"]
|
||||
elif model_name == "pi0_aloha_sim":
|
||||
batch["observation.images.top"] = batch["observation.images.cam_high"]
|
||||
del batch["observation.images.cam_high"]
|
||||
batch[f"{OBS_IMAGES}.top"] = batch[f"{OBS_IMAGES}.cam_high"]
|
||||
del batch[f"{OBS_IMAGES}.cam_high"]
|
||||
|
||||
# Batchify
|
||||
for key in batch:
|
||||
@@ -116,7 +117,7 @@ def main():
|
||||
actions.append(action)
|
||||
|
||||
actions = torch.stack(actions, dim=1)
|
||||
pi_actions = batch["action"]
|
||||
pi_actions = batch[ACTION]
|
||||
print("actions")
|
||||
display(actions)
|
||||
print()
|
||||
|
||||
@@ -57,13 +57,13 @@ import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0.paligemma_with_expert import (
|
||||
PaliGemmaWithExpertConfig,
|
||||
PaliGemmaWithExpertModel,
|
||||
)
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -35,6 +34,7 @@ from lerobot.processor import (
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
|
||||
|
||||
@@ -6,6 +6,7 @@ from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0fast")
|
||||
@@ -99,7 +100,7 @@ class PI0FASTConfig(PreTrainedConfig):
|
||||
|
||||
def validate_features(self) -> None:
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
key = f"{OBS_IMAGES}.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 480, 640),
|
||||
|
||||
@@ -57,9 +57,9 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGe
|
||||
from transformers.cache_utils import HybridCache, StaticCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
PRECISION = {
|
||||
"float16": torch.float16,
|
||||
|
||||
@@ -18,7 +18,6 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -30,6 +29,7 @@ from lerobot.processor import (
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_pi0fast_pre_post_processors(
|
||||
|
||||
@@ -19,8 +19,8 @@ from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.optim.optimizers import MultiAdamConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
def is_image_feature(key: str) -> bool:
|
||||
@@ -139,8 +139,6 @@ class SACConfig(PreTrainedConfig):
|
||||
# Training parameter
|
||||
# Number of steps for online training
|
||||
online_steps: int = 1000000
|
||||
# Seed for the online environment
|
||||
online_env_seed: int = 10000
|
||||
# Capacity of the online replay buffer
|
||||
online_buffer_capacity: int = 100000
|
||||
# Capacity of the offline replay buffer
|
||||
@@ -225,7 +223,7 @@ class SACConfig(PreTrainedConfig):
|
||||
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
if ACTION not in self.output_features:
|
||||
raise ValueError("You must provide 'action' in the output features")
|
||||
|
||||
@property
|
||||
|
||||
@@ -31,6 +31,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, Tr
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
|
||||
|
||||
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
|
||||
|
||||
@@ -50,7 +51,7 @@ class SACPolicy(
|
||||
self.config = config
|
||||
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features["action"].shape[0]
|
||||
continuous_action_dim = config.output_features[ACTION].shape[0]
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
@@ -157,7 +158,7 @@ class SACPolicy(
|
||||
The computed loss tensor
|
||||
"""
|
||||
# Extract common components from batch
|
||||
actions: Tensor = batch["action"]
|
||||
actions: Tensor = batch[ACTION]
|
||||
observations: dict[str, Tensor] = batch["state"]
|
||||
observation_features: Tensor = batch.get("observation_feature")
|
||||
|
||||
@@ -513,17 +514,17 @@ class SACObservationEncoder(nn.Module):
|
||||
)
|
||||
|
||||
def _init_state_layers(self) -> None:
|
||||
self.has_env = "observation.environment_state" in self.config.input_features
|
||||
self.has_state = "observation.state" in self.config.input_features
|
||||
self.has_env = OBS_ENV_STATE in self.config.input_features
|
||||
self.has_state = OBS_STATE in self.config.input_features
|
||||
if self.has_env:
|
||||
dim = self.config.input_features["observation.environment_state"].shape[0]
|
||||
dim = self.config.input_features[OBS_ENV_STATE].shape[0]
|
||||
self.env_encoder = nn.Sequential(
|
||||
nn.Linear(dim, self.config.latent_dim),
|
||||
nn.LayerNorm(self.config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
if self.has_state:
|
||||
dim = self.config.input_features["observation.state"].shape[0]
|
||||
dim = self.config.input_features[OBS_STATE].shape[0]
|
||||
self.state_encoder = nn.Sequential(
|
||||
nn.Linear(dim, self.config.latent_dim),
|
||||
nn.LayerNorm(self.config.latent_dim),
|
||||
@@ -549,9 +550,9 @@ class SACObservationEncoder(nn.Module):
|
||||
cache = self.get_cached_image_features(obs)
|
||||
parts.append(self._encode_images(cache, detach))
|
||||
if self.has_env:
|
||||
parts.append(self.env_encoder(obs["observation.environment_state"]))
|
||||
parts.append(self.env_encoder(obs[OBS_ENV_STATE]))
|
||||
if self.has_state:
|
||||
parts.append(self.state_encoder(obs["observation.state"]))
|
||||
parts.append(self.state_encoder(obs[OBS_STATE]))
|
||||
if parts:
|
||||
return torch.cat(parts, dim=-1)
|
||||
|
||||
@@ -1060,15 +1061,3 @@ class TanhMultivariateNormalDiag(TransformedDistribution):
|
||||
x = transform(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||
converted_params = {}
|
||||
for outer_key, inner_dict in normalization_params.items():
|
||||
converted_params[outer_key] = {}
|
||||
for key, value in inner_dict.items():
|
||||
converted_params[outer_key][key] = torch.tensor(value)
|
||||
if "image" in outer_key:
|
||||
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
|
||||
|
||||
return converted_params
|
||||
|
||||
@@ -19,7 +19,6 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -31,6 +30,7 @@ from lerobot.processor import (
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_sac_pre_post_processors(
|
||||
|
||||
@@ -19,6 +19,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamWConfig, OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass(name="reward_classifier")
|
||||
@@ -69,7 +70,7 @@ class RewardClassifierConfig(PreTrainedConfig):
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate feature configurations."""
|
||||
has_image = any(key.startswith("observation.image") for key in self.input_features)
|
||||
has_image = any(key.startswith(OBS_IMAGE) for key in self.input_features)
|
||||
if not has_image:
|
||||
raise ValueError(
|
||||
"You must provide an image observation (key starting with 'observation.image') in the input features"
|
||||
|
||||
@@ -19,9 +19,9 @@ import logging
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import OBS_IMAGE, REWARD
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.utils.constants import OBS_IMAGE, REWARD
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
|
||||
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("smolvla")
|
||||
@@ -117,7 +118,7 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
|
||||
def validate_features(self) -> None:
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
key = f"{OBS_IMAGES}.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 480, 640),
|
||||
|
||||
@@ -59,13 +59,13 @@ import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
from lerobot.policies.utils import (
|
||||
populate_queues,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -34,6 +33,7 @@ from lerobot.processor import (
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_smolvla_pre_post_processors(
|
||||
|
||||
@@ -35,10 +35,10 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD
|
||||
|
||||
|
||||
class TDMPCPolicy(PreTrainedPolicy):
|
||||
@@ -91,13 +91,13 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
called on `env.reset()`
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
|
||||
OBS_STATE: deque(maxlen=1),
|
||||
ACTION: deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
|
||||
}
|
||||
if self.config.image_features:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
self._queues[OBS_IMAGE] = deque(maxlen=1)
|
||||
if self.config.env_state_feature:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
self._queues[OBS_ENV_STATE] = deque(maxlen=1)
|
||||
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
||||
# CEM for the next step.
|
||||
self._prev_mean: torch.Tensor | None = None
|
||||
@@ -325,7 +325,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
|
||||
action = batch[ACTION] # (t, b, action_dim)
|
||||
reward = batch[REWARD] # (t, b)
|
||||
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
observations = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)}
|
||||
|
||||
# Apply random image augmentations.
|
||||
if self.config.image_features and self.config.max_random_shift_ratio > 0:
|
||||
@@ -387,10 +387,10 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
|
||||
# `z_preds` depends on the current observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch[f"{OBS_STR}.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# `z_targets` depends on the next observation.
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
* ~batch[f"{OBS_STR}.state_is_pad"][1:]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
@@ -403,7 +403,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
* F.mse_loss(reward_preds, reward, reduction="none")
|
||||
* ~batch["next.reward_is_pad"]
|
||||
# `reward_preds` depends on the current observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch[f"{OBS_STR}.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
)
|
||||
.sum(0)
|
||||
@@ -419,11 +419,11 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
reduction="none",
|
||||
).sum(0) # sum over ensemble
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch[f"{OBS_STR}.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# q_targets depends on the reward and the next observations.
|
||||
* ~batch["next.reward_is_pad"]
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
* ~batch[f"{OBS_STR}.state_is_pad"][1:]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
@@ -441,7 +441,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
temporal_loss_coeffs
|
||||
* raw_v_value_loss
|
||||
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch[f"{OBS_STR}.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
)
|
||||
.sum(0)
|
||||
@@ -477,7 +477,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
* mse
|
||||
* temporal_loss_coeffs
|
||||
# `action_preds` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch[f"{OBS_STR}.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
).mean()
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -30,6 +29,7 @@ from lerobot.processor import (
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_tdmpc_pre_post_processors(
|
||||
|
||||
@@ -82,7 +82,6 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
gpt_n_head: Number of headers of GPT
|
||||
gpt_hidden_dim: Size of hidden dimensions of GPT
|
||||
dropout: Dropout rate for GPT
|
||||
mlp_hidden_dim: Size of hidden dimensions of offset header / bin prediction headers parts of VQ-BeT
|
||||
offset_loss_weight: A constant that is multiplied to the offset loss
|
||||
primary_code_loss_weight: A constant that is multiplied to the primary code prediction loss
|
||||
secondary_code_loss_weight: A constant that is multiplied to the secondary code prediction loss
|
||||
@@ -125,7 +124,6 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
gpt_n_head: int = 8
|
||||
gpt_hidden_dim: int = 512
|
||||
dropout: float = 0.1
|
||||
mlp_hidden_dim: int = 1024
|
||||
offset_loss_weight: float = 10000.0
|
||||
primary_code_loss_weight: float = 5.0
|
||||
secondary_code_loss_weight: float = 0.5
|
||||
|
||||
@@ -27,11 +27,11 @@ import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.policies.vqbet.vqbet_utils import GPT, ResidualVQ
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
# ruff: noqa: N806
|
||||
|
||||
@@ -133,7 +133,7 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
batch.pop(ACTION)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
# NOTE: It's important that this happens after stacking the images into a single key.
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
@@ -340,14 +340,12 @@ class VQBeTModel(nn.Module):
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({"observation.state", "observation.images"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
assert set(batch).issuperset({OBS_STATE, OBS_IMAGES})
|
||||
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
# Extract image feature (first combine batch and sequence dims).
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
img_features = self.rgb_encoder(einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ..."))
|
||||
# Separate batch and sequence dims.
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
|
||||
@@ -359,9 +357,7 @@ class VQBeTModel(nn.Module):
|
||||
img_features
|
||||
) # (batch, obs_step, number of different cameras, projection dims)
|
||||
input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))]
|
||||
input_tokens.append(
|
||||
self.state_projector(batch["observation.state"])
|
||||
) # (batch, obs_step, projection dims)
|
||||
input_tokens.append(self.state_projector(batch[OBS_STATE])) # (batch, obs_step, projection dims)
|
||||
input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
|
||||
# Interleave tokens by stacking and rearranging.
|
||||
input_tokens = torch.stack(input_tokens, dim=2)
|
||||
|
||||
@@ -19,7 +19,6 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -31,6 +30,7 @@ from lerobot.processor import (
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_vqbet_pre_post_processors(
|
||||
|
||||
@@ -231,16 +231,6 @@ class GPT(nn.Module):
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
torch.nn.init.ones_(module.weight)
|
||||
|
||||
def crop_block_size(self, gpt_block_size):
|
||||
# model surgery to decrease the block size if necessary
|
||||
# e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
|
||||
# but want to use a smaller block size for some smaller, simpler model
|
||||
assert gpt_block_size <= self.config.gpt_block_size
|
||||
self.config.gpt_block_size = gpt_block_size
|
||||
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
|
||||
for block in self.transformer.h:
|
||||
block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
|
||||
|
||||
def configure_parameters(self):
|
||||
"""
|
||||
This long function is unfortunately doing something very simple and is being very defensive:
|
||||
@@ -270,13 +260,11 @@ class GPT(nn.Module):
|
||||
param_dict = dict(self.named_parameters())
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
|
||||
str(inter_params)
|
||||
assert len(inter_params) == 0, (
|
||||
f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
|
||||
)
|
||||
assert len(param_dict.keys() - union_params) == 0, (
|
||||
"parameters {} were not separated into either decay/no_decay set!".format(
|
||||
str(param_dict.keys() - union_params),
|
||||
)
|
||||
f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"
|
||||
)
|
||||
|
||||
decay = [param_dict[pn] for pn in sorted(decay)]
|
||||
|
||||
@@ -25,7 +25,7 @@ from dataclasses import dataclass, field
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from .core import EnvTransition, PolicyAction
|
||||
from .pipeline import (
|
||||
|
||||
@@ -23,6 +23,8 @@ from typing import Any
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_PREFIX, REWARD, TRUNCATED
|
||||
|
||||
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
||||
|
||||
|
||||
@@ -342,20 +344,20 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
|
||||
if not isinstance(batch, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
|
||||
|
||||
action = batch.get("action")
|
||||
action = batch.get(ACTION)
|
||||
if action is not None and not isinstance(action, PolicyAction):
|
||||
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
|
||||
|
||||
# Extract observation and complementary data keys.
|
||||
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
observation_keys = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)}
|
||||
complementary_data = _extract_complementary_data(batch)
|
||||
|
||||
return create_transition(
|
||||
observation=observation_keys if observation_keys else None,
|
||||
action=batch.get("action"),
|
||||
reward=batch.get("next.reward", 0.0),
|
||||
done=batch.get("next.done", False),
|
||||
truncated=batch.get("next.truncated", False),
|
||||
action=batch.get(ACTION),
|
||||
reward=batch.get(REWARD, 0.0),
|
||||
done=batch.get(DONE, False),
|
||||
truncated=batch.get(TRUNCATED, False),
|
||||
info=batch.get("info", {}),
|
||||
complementary_data=complementary_data if complementary_data else None,
|
||||
)
|
||||
@@ -377,10 +379,10 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
|
||||
|
||||
batch = {
|
||||
"action": transition.get(TransitionKey.ACTION),
|
||||
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
|
||||
"next.done": transition.get(TransitionKey.DONE, False),
|
||||
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
|
||||
ACTION: transition.get(TransitionKey.ACTION),
|
||||
REWARD: transition.get(TransitionKey.REWARD, 0.0),
|
||||
DONE: transition.get(TransitionKey.DONE, False),
|
||||
TRUNCATED: transition.get(TransitionKey.TRUNCATED, False),
|
||||
"info": transition.get(TransitionKey.INFO, {}),
|
||||
}
|
||||
|
||||
|
||||
@@ -83,14 +83,12 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
||||
|
||||
Attributes:
|
||||
position_scale: A factor to scale the delta position inputs.
|
||||
rotation_scale: A factor to scale the delta rotation inputs (currently unused).
|
||||
noise_threshold: The magnitude below which delta inputs are considered noise
|
||||
and do not trigger an "enabled" state.
|
||||
"""
|
||||
|
||||
# Scale factors for delta movements
|
||||
position_scale: float = 1.0
|
||||
rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard
|
||||
noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise
|
||||
|
||||
def action(self, action: RobotAction) -> RobotAction:
|
||||
|
||||
@@ -340,7 +340,7 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
action = self.transition.get(TransitionKey.ACTION)
|
||||
|
||||
raw_joint_positions = complementary_data.get("raw_joint_positions", None)
|
||||
raw_joint_positions = complementary_data.get("raw_joint_positions")
|
||||
if raw_joint_positions is None:
|
||||
return complementary_data
|
||||
|
||||
|
||||
@@ -20,12 +20,12 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_STATE
|
||||
from lerobot.processor.pipeline import (
|
||||
ObservationProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
from lerobot.robots import Robot
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -59,6 +59,7 @@ from safetensors.torch import load_file as load_safetensors
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||
@@ -196,7 +197,7 @@ def detect_features_and_norm_modes(
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in key:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in key:
|
||||
elif ACTION in key:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE # Default
|
||||
@@ -215,7 +216,7 @@ def detect_features_and_norm_modes(
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in key or "joint" in key or "position" in key:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in key:
|
||||
elif ACTION in key:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE
|
||||
@@ -321,7 +322,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in key:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in key:
|
||||
elif ACTION in key:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE
|
||||
|
||||
@@ -26,6 +26,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
from .converters import from_tensor_to_numpy, to_tensor
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
@@ -118,13 +119,12 @@ class _NormalizationMixin:
|
||||
)
|
||||
self.features = reconstructed
|
||||
|
||||
if self.norm_map:
|
||||
# if keys are strings (JSON), rebuild enum map
|
||||
if all(isinstance(k, str) for k in self.norm_map.keys()):
|
||||
reconstructed = {}
|
||||
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||
reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||
self.norm_map = reconstructed
|
||||
# if keys are strings (JSON), rebuild enum map
|
||||
if self.norm_map and all(isinstance(k, str) for k in self.norm_map):
|
||||
reconstructed = {}
|
||||
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||
reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||
self.norm_map = reconstructed
|
||||
|
||||
# Convert stats to tensors and move to the target device once during initialization.
|
||||
self.stats = self.stats or {}
|
||||
@@ -272,7 +272,7 @@ class _NormalizationMixin:
|
||||
Returns:
|
||||
The transformed action tensor.
|
||||
"""
|
||||
processed_action = self._apply_transform(action, "action", FeatureType.ACTION, inverse=inverse)
|
||||
processed_action = self._apply_transform(action, ACTION, FeatureType.ACTION, inverse=inverse)
|
||||
return processed_action
|
||||
|
||||
def _apply_transform(
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
@@ -152,7 +152,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
# Build a new features mapping keyed by the same FeatureType buckets
|
||||
# We assume callers already placed features in the correct FeatureType.
|
||||
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features.keys()}
|
||||
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features}
|
||||
|
||||
exact_pairs = {
|
||||
"pixels": OBS_IMAGE,
|
||||
@@ -171,7 +171,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
|
||||
|
||||
# Prefix-based rules (e.g. pixels.cam1 -> OBS_IMAGES.cam1)
|
||||
for old_prefix, new_prefix in prefix_pairs.items():
|
||||
prefixed_old = f"observation.{old_prefix}"
|
||||
prefixed_old = f"{OBS_STR}.{old_prefix}"
|
||||
if key.startswith(prefixed_old):
|
||||
suffix = key[len(prefixed_old) :]
|
||||
new_key = f"{new_prefix}{suffix}"
|
||||
@@ -191,7 +191,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
|
||||
|
||||
# Exact-name rules (pixels, environment_state, agent_pos)
|
||||
for old, new in exact_pairs.items():
|
||||
if key == old or key == f"observation.{old}":
|
||||
if key == old or key == f"{OBS_STR}.{old}":
|
||||
new_key = new
|
||||
new_features[src_ft][new_key] = feat
|
||||
handled = True
|
||||
|
||||
@@ -422,7 +422,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
||||
"""
|
||||
if save_directory is None:
|
||||
# Use default directory in HF_LEROBOT_HOME
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
save_directory = HF_LEROBOT_HOME / "processors" / sanitized_name
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import ActionProcessorStep, PolicyAction, ProcessorStepRegistry, RobotAction
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -23,7 +24,7 @@ class RobotActionToPolicyActionProcessorStep(ActionProcessorStep):
|
||||
return asdict(self)
|
||||
|
||||
def transform_features(self, features):
|
||||
features[PipelineFeatureType.ACTION]["action"] = PolicyFeature(
|
||||
features[PipelineFeatureType.ACTION][ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(len(self.motor_names),)
|
||||
)
|
||||
return features
|
||||
|
||||
@@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
|
||||
@@ -35,7 +35,7 @@ gamepad to take control of the robot during training. Initially intervene freque
|
||||
reduce interventions as the policy improves.
|
||||
|
||||
**WORKFLOW**:
|
||||
1. Determine robot workspace bounds using `find_joint_limits.py`
|
||||
1. Determine robot workspace bounds using `lerobot-find-joint-limits`
|
||||
2. Record demonstrations with `gym_manipulator.py` in record mode
|
||||
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
|
||||
4. Start the learner server with the training configuration
|
||||
@@ -63,6 +63,8 @@ from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.rl.queue import get_last_item_from_queue
|
||||
from lerobot.robots import so100_follower # noqa: F401
|
||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
@@ -75,8 +77,6 @@ from lerobot.transport.utils import (
|
||||
send_bytes_in_chunks,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.queue import get_last_item_from_queue
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.transition import (
|
||||
@@ -97,8 +97,6 @@ from .gym_manipulator import (
|
||||
step_env_and_process_transition,
|
||||
)
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
# Main entry point
|
||||
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ import torch.nn.functional as F # noqa: N812
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, REWARD
|
||||
from lerobot.utils.transition import Transition
|
||||
|
||||
|
||||
@@ -175,7 +176,7 @@ class ReplayBuffer:
|
||||
self.complementary_info[key] = torch.empty(
|
||||
(self.capacity, *value_shape), device=self.storage_device
|
||||
)
|
||||
elif isinstance(value, (int, float)):
|
||||
elif isinstance(value, (int | float)):
|
||||
# Handle scalar values similar to reward
|
||||
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
|
||||
else:
|
||||
@@ -222,7 +223,7 @@ class ReplayBuffer:
|
||||
value = complementary_info[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
|
||||
elif isinstance(value, (int, float)):
|
||||
elif isinstance(value, (int | float)):
|
||||
self.complementary_info[key][self.position] = value
|
||||
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
@@ -240,7 +241,7 @@ class ReplayBuffer:
|
||||
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
|
||||
|
||||
# Identify image keys that need augmentation
|
||||
image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
|
||||
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
|
||||
|
||||
# Create batched state and next_state
|
||||
batch_state = {}
|
||||
@@ -466,7 +467,7 @@ class ReplayBuffer:
|
||||
if list_transition:
|
||||
first_transition = list_transition[0]
|
||||
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
|
||||
first_action = first_transition["action"].to(device)
|
||||
first_action = first_transition[ACTION].to(device)
|
||||
|
||||
# Get complementary info if available
|
||||
first_complementary_info = None
|
||||
@@ -491,7 +492,7 @@ class ReplayBuffer:
|
||||
elif isinstance(v, torch.Tensor):
|
||||
data[k] = v.to(storage_device)
|
||||
|
||||
action = data["action"]
|
||||
action = data[ACTION]
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
@@ -529,12 +530,12 @@ class ReplayBuffer:
|
||||
|
||||
# Add "action"
|
||||
sample_action = self.actions[0]
|
||||
act_info = guess_feature_info(t=sample_action, name="action")
|
||||
features["action"] = act_info
|
||||
act_info = guess_feature_info(t=sample_action, name=ACTION)
|
||||
features[ACTION] = act_info
|
||||
|
||||
# Add "reward" and "done"
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
|
||||
features["next.done"] = {"dtype": "bool", "shape": (1,)}
|
||||
features[REWARD] = {"dtype": "float32", "shape": (1,)}
|
||||
features[DONE] = {"dtype": "bool", "shape": (1,)}
|
||||
|
||||
# Add state keys
|
||||
for key in self.states:
|
||||
@@ -576,9 +577,9 @@ class ReplayBuffer:
|
||||
frame_dict[key] = self.states[key][actual_idx].cpu()
|
||||
|
||||
# Fill action, reward, done
|
||||
frame_dict["action"] = self.actions[actual_idx].cpu()
|
||||
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||
frame_dict[ACTION] = self.actions[actual_idx].cpu()
|
||||
frame_dict[REWARD] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||
frame_dict[DONE] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||
frame_dict["task"] = task_name
|
||||
|
||||
# Add complementary_info if available
|
||||
@@ -647,7 +648,7 @@ class ReplayBuffer:
|
||||
|
||||
# Check if the dataset has "next.done" key
|
||||
sample = dataset[0]
|
||||
has_done_key = "next.done" in sample
|
||||
has_done_key = DONE in sample
|
||||
|
||||
# Check for complementary_info keys
|
||||
complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
|
||||
@@ -667,14 +668,14 @@ class ReplayBuffer:
|
||||
current_state[key] = val.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 2) Action -----
|
||||
action = current_sample["action"].unsqueeze(0) # Add batch dimension
|
||||
action = current_sample[ACTION].unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 3) Reward and done -----
|
||||
reward = float(current_sample["next.reward"].item()) # ensure float
|
||||
reward = float(current_sample[REWARD].item()) # ensure float
|
||||
|
||||
# Determine done flag - use next.done if available, otherwise infer from episode boundaries
|
||||
if has_done_key:
|
||||
done = bool(current_sample["next.done"].item()) # ensure bool
|
||||
done = bool(current_sample[DONE].item()) # ensure bool
|
||||
else:
|
||||
# If this is the last frame or if next frame is in a different episode, mark as done
|
||||
done = False
|
||||
@@ -787,8 +788,8 @@ def concatenate_batch_transitions(
|
||||
}
|
||||
|
||||
# Concatenate basic fields
|
||||
left_batch_transitions["action"] = torch.cat(
|
||||
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
|
||||
left_batch_transitions[ACTION] = torch.cat(
|
||||
[left_batch_transitions[ACTION], right_batch_transition[ACTION]], dim=0
|
||||
)
|
||||
left_batch_transitions["reward"] = torch.cat(
|
||||
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
|
||||
@@ -25,6 +25,7 @@ import torchvision.transforms.functional as F # type: ignore # noqa: N812
|
||||
from tqdm import tqdm # type: ignore
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import DONE, REWARD
|
||||
|
||||
|
||||
def select_rect_roi(img):
|
||||
@@ -159,7 +160,7 @@ def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
|
||||
return image_dict
|
||||
|
||||
|
||||
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
def convert_lerobot_dataset_to_cropped_lerobot_dataset(
|
||||
original_dataset: LeRobotDataset,
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]],
|
||||
new_repo_id: str,
|
||||
@@ -189,7 +190,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
# 1. Create a new (empty) LeRobotDataset for writing.
|
||||
new_dataset = LeRobotDataset.create(
|
||||
repo_id=new_repo_id,
|
||||
fps=original_dataset.fps,
|
||||
fps=int(original_dataset.fps),
|
||||
root=new_dataset_root,
|
||||
robot_type=original_dataset.meta.robot_type,
|
||||
features=original_dataset.meta.info["features"],
|
||||
@@ -212,7 +213,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
for key, value in frame.items():
|
||||
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"):
|
||||
continue
|
||||
if key in ("next.done", "next.reward"):
|
||||
if key in (DONE, REWARD):
|
||||
# if not isinstance(value, str) and len(value.shape) == 0:
|
||||
value = value.unsqueeze(0)
|
||||
|
||||
@@ -274,6 +275,12 @@ if __name__ == "__main__":
|
||||
default="",
|
||||
help="The natural language task to describe the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--new-repo-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The repository id for the new cropped and resized dataset. If not provided, it defaults to `repo_id` + '_cropped_resized'.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
|
||||
@@ -293,10 +300,16 @@ if __name__ == "__main__":
|
||||
for key, roi in rois.items():
|
||||
print(f"{key}: {roi}")
|
||||
|
||||
new_repo_id = args.repo_id + "_cropped_resized"
|
||||
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
|
||||
new_repo_id = args.new_repo_id if args.new_repo_id else args.repo_id + "_cropped_resized"
|
||||
|
||||
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
if args.new_repo_id:
|
||||
new_dataset_name = args.new_repo_id.split("/")[-1]
|
||||
# Parent 1: HF user, Parent 2: HF LeRobot Home
|
||||
new_dataset_root = dataset.root.parent.parent / new_dataset_name
|
||||
else:
|
||||
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
|
||||
|
||||
cropped_resized_dataset = convert_lerobot_dataset_to_cropped_lerobot_dataset(
|
||||
original_dataset=dataset,
|
||||
crop_params_dict=rois,
|
||||
new_repo_id=new_repo_id,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user