mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
fix bugs when testing on hardware
This commit is contained in:
@@ -40,17 +40,16 @@ Here's a complete training command for training Multi-Task DiT on your dataset:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-train \
|
lerobot-train \
|
||||||
--dataset.repo_id=$DATASET_ID \
|
--dataset.repo_id={{MY_DATASET_ID}} \
|
||||||
--output_dir=$OUTPUT_DIR \
|
--output_dir={{MY_OUTPUT_DIR}} \
|
||||||
--job_name=$JOB_NAME \
|
|
||||||
--policy.type=multi_task_dit \
|
--policy.type=multi_task_dit \
|
||||||
--policy.device=cuda \
|
--policy.device=cuda \
|
||||||
|
--policy.repo_id={{MY_REPO_ID}}
|
||||||
--batch_size=32 \
|
--batch_size=32 \
|
||||||
--steps=5000 \
|
--steps=5000 \
|
||||||
--save_freq=500 \
|
--save_freq=500 \
|
||||||
--log_freq=100 \
|
--log_freq=100 \
|
||||||
--wandb.enable=true \
|
--wandb.enable=true
|
||||||
--policy.repo_id=$REPO_ID
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Recommended Hyperparameters and Dataset Details (30Hz Control Frequency)
|
### Recommended Hyperparameters and Dataset Details (30Hz Control Frequency)
|
||||||
@@ -59,15 +58,15 @@ For reliable performance, start with these suggested default hyperparameters:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-train \
|
lerobot-train \
|
||||||
--dataset.repo_id=$DATASET_ID \
|
--dataset.repo_id={{MY_DATASET_ID}} \
|
||||||
--output_dir=$OUTPUT_DIR \
|
--output_dir={{MY_OUTPUT_DIR}} \
|
||||||
--job_name=$JOB_NAME \
|
|
||||||
--policy.type=multi_task_dit \
|
--policy.type=multi_task_dit \
|
||||||
--policy.device=cuda \
|
--policy.device=cuda \
|
||||||
--batch_size=320 \
|
--batch_size=320 \
|
||||||
--steps=30000 \
|
--steps=30000 \
|
||||||
--policy.horizon=32 \
|
--policy.horizon=32 \
|
||||||
--policy.n_action_steps=24 \
|
--policy.n_action_steps=24 \
|
||||||
|
--policy.repo_id={{MY_REPO_ID}} \
|
||||||
--policy.objective=diffusion \
|
--policy.objective=diffusion \
|
||||||
--policy.noise_scheduler_type=DDPM \
|
--policy.noise_scheduler_type=DDPM \
|
||||||
--policy.num_train_timesteps=100 \
|
--policy.num_train_timesteps=100 \
|
||||||
@@ -263,8 +262,8 @@ Here's a complete example training on a custom dataset:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-train \
|
lerobot-train \
|
||||||
--dataset.repo_id=your_username/your_dataset \
|
--dataset.repo_id={{MY_DATASET_ID}} \
|
||||||
--output_dir=outputs/multitask_dit_training \
|
--output_dir={{MY_OUTPUT_DIR}} \
|
||||||
--policy.type=multi_task_dit \
|
--policy.type=multi_task_dit \
|
||||||
--policy.device=cuda \
|
--policy.device=cuda \
|
||||||
--batch_size=320 \
|
--batch_size=320 \
|
||||||
@@ -283,7 +282,7 @@ lerobot-train \
|
|||||||
--policy.image_crop_shape=[224,224] \
|
--policy.image_crop_shape=[224,224] \
|
||||||
--wandb.enable=true \
|
--wandb.enable=true \
|
||||||
--wandb.project=multitask_dit \
|
--wandb.project=multitask_dit \
|
||||||
--policy.repo_id=your_username/multitask_dit_policy
|
--policy.repo_id={{MY_REPO_ID}}
|
||||||
```
|
```
|
||||||
|
|
||||||
## References
|
## References
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.types import NormalizationMode
|
from lerobot.configs.types import NormalizationMode
|
||||||
@@ -37,7 +36,7 @@ class MultiTaskDiTConfig(PreTrainedConfig):
|
|||||||
n_action_steps: int = 24 # Actions executed per policy call (~0.8s at 30Hz)
|
n_action_steps: int = 24 # Actions executed per policy call (~0.8s at 30Hz)
|
||||||
|
|
||||||
# Objective Selection
|
# Objective Selection
|
||||||
objective: Literal["diffusion", "flow_matching"] = "diffusion"
|
objective: str = "diffusion" # "diffusion" or "flow_matching"
|
||||||
|
|
||||||
# --- Diffusion-specific (used when objective="diffusion") ---
|
# --- Diffusion-specific (used when objective="diffusion") ---
|
||||||
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
|
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
|
||||||
@@ -54,7 +53,7 @@ class MultiTaskDiTConfig(PreTrainedConfig):
|
|||||||
sigma_min: float = 0.0 # Minimum noise in flow interpolation path
|
sigma_min: float = 0.0 # Minimum noise in flow interpolation path
|
||||||
num_integration_steps: int = 100 # ODE integration steps at inference
|
num_integration_steps: int = 100 # ODE integration steps at inference
|
||||||
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
|
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
|
||||||
timestep_sampling_strategy: Literal["uniform", "beta"] = "beta"
|
timestep_sampling_strategy: str = "beta" # "uniform" or "beta"
|
||||||
|
|
||||||
timestep_sampling_s: float = 0.999 # (beta only) Max timestep threshold
|
timestep_sampling_s: float = 0.999 # (beta only) Max timestep threshold
|
||||||
timestep_sampling_alpha: float = 1.5 # (beta only) Beta distribution alpha
|
timestep_sampling_alpha: float = 1.5 # (beta only) Beta distribution alpha
|
||||||
@@ -112,6 +111,12 @@ class MultiTaskDiTConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
def _validate(self):
|
def _validate(self):
|
||||||
"""Validate configuration parameters."""
|
"""Validate configuration parameters."""
|
||||||
|
# Objective validation
|
||||||
|
if self.objective not in ["diffusion", "flow_matching"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'"
|
||||||
|
)
|
||||||
|
|
||||||
# Transformer validation
|
# Transformer validation
|
||||||
if self.hidden_dim <= 0:
|
if self.hidden_dim <= 0:
|
||||||
raise ValueError("hidden_dim must be positive")
|
raise ValueError("hidden_dim must be positive")
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ Supports both diffusion and flow matching objectives for action generation.
|
|||||||
|
|
||||||
References:
|
References:
|
||||||
- https://arxiv.org/abs/2507.05331
|
- https://arxiv.org/abs/2507.05331
|
||||||
|
- https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/
|
||||||
- https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy
|
- https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy
|
||||||
- https://brysonkjones.substack.com/p/dissecting-multitask-diffusion-transformer-policy
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|||||||
Reference in New Issue
Block a user