Compare commits


4 Commits

Author SHA1 Message Date
Caroline Pascal 8fff0fde7c chore(docstrings): fixing deprecated root argument description in LeRobotDataset class (#3035)
* chore(docstrings): fixing deprecated `root` argument docstrings in LeRobotDataset class

* chore(draccus): updating draccus CLI help

* chore(revert): reverting changes in lerobot_dataset_viz.py

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-27 18:22:44 +01:00
Pepijn 04de496547 fix(logging): avoid double-counting samples across processes (#3045) 2026-02-27 17:45:19 +01:00
Khalil Meftah baf9b50365 Fix(diffusion): enforce no-crop behavior when crop_ratio=1.0 (#3046)
* refactor(diffusion): replace crop_shape with resize_shape and crop_ratio

* fix(diffusion): address review feedback on resize/crop backward compat

* test: regenerate diffusion artifacts for updated default config

* fix: disable crop when resize path uses crop_ratio=1.0

---------

Co-authored-by: starlitxiling <1754165401@qq.com>
2026-02-27 17:44:53 +01:00
Jade Choghari a0fdbf037a feat(policies): add Smolvla torch compile support (#3043)
* Change LIBERO init_state_id when reset.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* Change LIBERO init_state_id when reset.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* pre-commit run

* Add torch.compile for smolvla

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* Add torch.compile for smolvla

Add model compilation option for improved performance.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* first

---------

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>
Co-authored-by: Aoqun Jin <aojiaojiao@foxmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-27 18:58:36 +03:00
19 changed files with 122 additions and 58 deletions
+1 -1
@@ -57,7 +57,7 @@ class DatasetReplayConfig:
     repo_id: str
     # Episode to replay.
     episode: int
-    # Root directory where the dataset will be stored (e.g. 'dataset/path').
+    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
     root: str | Path | None = None
     # Limit the frames per second. By default, uses the policy fps.
     fps: int = 30
+1 -1
@@ -27,7 +27,7 @@ class DatasetConfig:
    # "dataset_index" into the returned item. The index mapping is made according to the order in which the
    # datasets are provided.
    repo_id: str
-    # Root directory where the dataset will be stored (e.g. 'dataset/path').
+    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
    root: str | None = None
    episodes: list[int] | None = None
    image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
+5 -5
@@ -664,11 +664,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
            for the README).
        Args:
-            repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
-                will be stored under root/repo_id.
-            root (Path | None, optional): Local directory to use for downloading/writing files. You can also
-                set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
-                '~/.cache/huggingface/lerobot'.
+            repo_id (str): This is the repo id that will be used to fetch the dataset.
+            root (Path | None, optional): Local directory where the dataset will be downloaded and
+                stored. If set, all dataset files will be stored directly under this path. If not set, the
+                dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the
+                HF_LEROBOT_HOME environment variable).
            episodes (list[int] | None, optional): If specified, this will only load episodes specified by
                their episode_index in this list. Defaults to None.
            image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
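In practice the documented resolution order is: an explicit root wins, otherwise the dataset lands under $HF_LEROBOT_HOME/repo_id. A minimal sketch of that rule, assuming only what the docstring states (the helper name and the '~/.cache/huggingface/lerobot' fallback used when the env var is unset are illustrative, not the library's actual implementation):

import os
from pathlib import Path

# Hypothetical helper mirroring the documented `root` resolution.
def resolve_dataset_root(repo_id: str, root: str | Path | None = None) -> Path:
    if root is not None:
        # Explicit root: all dataset files go directly under this path.
        return Path(root)
    # Fallback: $HF_LEROBOT_HOME/repo_id, defaulting to the HF cache location.
    hf_lerobot_home = Path(
        os.environ.get("HF_LEROBOT_HOME", str(Path.home() / ".cache/huggingface/lerobot"))
    )
    return hf_lerobot_home / repo_id

print(resolve_dataset_root("lerobot/pusht"))              # ~/.cache/huggingface/lerobot/lerobot/pusht
print(resolve_dataset_root("lerobot/pusht", "/data/ds"))  # /data/ds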
@@ -55,10 +55,16 @@ class DiffusionConfig(PreTrainedConfig):
        normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
            a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
-        crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
-            within the image size. If None, no cropping is done.
-        crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
-            mode).
+        resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision
+            backbone. If None, no resizing is done and the original image resolution is used.
+        crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape
+            (crop_h = int(resize_shape[0] * crop_ratio), likewise for width).
+            Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None.
+        crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0,
+            this is computed automatically. Can also be set directly for legacy configs that use
+            crop-only (without resize). If None and no derivation applies, no cropping is done.
+        crop_is_random: Whether the crop should be random at training time (it's always a center
+            crop in eval mode).
        pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
            `None` means no pretrained weights.
        use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
@@ -114,7 +120,9 @@ class DiffusionConfig(PreTrainedConfig):
    # Architecture / modeling.
    # Vision backbone.
    vision_backbone: str = "resnet18"
-    crop_shape: tuple[int, int] | None = (84, 84)
+    resize_shape: tuple[int, int] | None = None
+    crop_ratio: float = 1.0
+    crop_shape: tuple[int, int] | None = None
    crop_is_random: bool = True
    pretrained_backbone_weights: str | None = None
    use_group_norm: bool = True
@@ -175,6 +183,25 @@ class DiffusionConfig(PreTrainedConfig):
                f"Got {self.noise_scheduler_type}."
            )
+        if self.resize_shape is not None and (
+            len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape)
+        ):
+            raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.")
+        if not (0 < self.crop_ratio <= 1.0):
+            raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.")
+        if self.resize_shape is not None:
+            if self.crop_ratio < 1.0:
+                self.crop_shape = (
+                    int(self.resize_shape[0] * self.crop_ratio),
+                    int(self.resize_shape[1] * self.crop_ratio),
+                )
+            else:
+                # Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0.
+                self.crop_shape = None
+        if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0):
+            raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.")
        # Check that the horizon size and U-Net downsampling is compatible.
        # U-Net downsamples by 2 with each stage.
        downsampling_factor = 2 ** len(self.down_dims)
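As a quick check of the derivation above (plain arithmetic mirroring the __post_init__ logic in the hunk; the values are made up):

# crop_ratio < 1.0 derives the crop size per dimension via int() truncation.
resize_shape = (96, 96)
crop_ratio = 0.9
crop_shape = (int(resize_shape[0] * crop_ratio), int(resize_shape[1] * crop_ratio))
assert crop_shape == (86, 86)  # 96 * 0.9 = 86.4, truncated to 86

# crop_ratio == 1.0 takes the else branch: crop_shape is forced to None,
# which is the no-crop behavior this PR enforces.
crop_ratio = 1.0
crop_shape = None if crop_ratio == 1.0 else crop_shape
assert crop_shape is None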
@@ -202,13 +229,12 @@ class DiffusionConfig(PreTrainedConfig):
        if len(self.image_features) == 0 and self.env_state_feature is None:
            raise ValueError("You must provide at least one image or the environment state among the inputs.")
-        if self.crop_shape is not None:
+        if self.resize_shape is None and self.crop_shape is not None:
            for key, image_ft in self.image_features.items():
                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                    raise ValueError(
-                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
-                        f"for `crop_shape` and {image_ft.shape} for "
-                        f"`{key}`."
+                        f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
+                        f"for `crop_shape` and {image_ft.shape} for `{key}`."
                    )
        # Check that all input images have the same shape.
@@ -454,12 +454,18 @@ class DiffusionRgbEncoder(nn.Module):
def __init__(self, config: DiffusionConfig): def __init__(self, config: DiffusionConfig):
super().__init__() super().__init__()
# Set up optional preprocessing. # Set up optional preprocessing.
if config.crop_shape is not None: if config.resize_shape is not None:
self.resize = torchvision.transforms.Resize(config.resize_shape)
else:
self.resize = None
crop_shape = config.crop_shape
if crop_shape is not None:
self.do_crop = True self.do_crop = True
# Always use center crop for eval # Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape) self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
if config.crop_is_random: if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape) self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
else: else:
self.maybe_random_crop = self.center_crop self.maybe_random_crop = self.center_crop
else: else:
@@ -485,13 +491,16 @@ class DiffusionRgbEncoder(nn.Module):
        # Set up pooling and final layers.
        # Use a dry run to get the feature map shape.
-        # The dummy input should take the number of image channels from `config.image_features` and it should
-        # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
-        # height and width from `config.image_features`.
+        # The dummy shape mirrors the runtime preprocessing order: resize -> crop.
        # Note: we have a check in the config class to make sure all images have the same shape.
        images_shape = next(iter(config.image_features.values())).shape
-        dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
+        if config.crop_shape is not None:
+            dummy_shape_h_w = config.crop_shape
+        elif config.resize_shape is not None:
+            dummy_shape_h_w = config.resize_shape
+        else:
+            dummy_shape_h_w = images_shape[1:]
        dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
        feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
@@ -507,7 +516,10 @@ class DiffusionRgbEncoder(nn.Module):
        Returns:
            (B, D) image feature.
        """
-        # Preprocess: maybe crop (if it was set up in the __init__).
+        # Preprocess: resize if configured, then crop if configured.
+        if self.resize is not None:
+            x = self.resize(x)
        if self.do_crop:
            if self.training:  # noqa: SIM108
                x = self.maybe_random_crop(x)
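For intuition, the encoder's preprocessing now behaves like this standalone torchvision pipeline (a sketch with made-up shapes, not the encoder itself):

import torch
import torchvision

resize_shape, crop_shape = (96, 96), (86, 86)
resize = torchvision.transforms.Resize(resize_shape)
random_crop = torchvision.transforms.RandomCrop(crop_shape)  # training
center_crop = torchvision.transforms.CenterCrop(crop_shape)  # eval

x = torch.rand(1, 3, 240, 320)    # (B, C, H, W) camera frame
x_train = random_crop(resize(x))  # resize first, then crop
x_eval = center_crop(resize(x))
assert x_train.shape == x_eval.shape == (1, 3, 86, 86)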
+1 -8
@@ -995,14 +995,7 @@ class PI0Policy(PreTrainedPolicy):
        # Initialize model without loading weights
        # Check if dataset_stats were provided in kwargs
-        if _transformers_available:
-            from transformers.modeling_utils import no_init_weights
-
-            with no_init_weights():
-                model = cls(config, **kwargs)
-                model.model.paligemma_with_expert.paligemma.tie_weights()
-        else:
-            model = cls(config, **kwargs)
+        model = cls(config, **kwargs)
        # Now manually load and remap the state dict
        try:
+1 -8
@@ -967,14 +967,7 @@ class PI05Policy(PreTrainedPolicy):
        # Initialize model without loading weights
        # Check if dataset_stats were provided in kwargs
-        if _transformers_available:
-            from transformers.modeling_utils import no_init_weights
-
-            with no_init_weights():
-                model = cls(config, **kwargs)
-                model.model.paligemma_with_expert.paligemma.tie_weights()
-        else:
-            model = cls(config, **kwargs)
+        model = cls(config, **kwargs)
        # Now manually load and remap the state dict
        try:
@@ -895,14 +895,7 @@ class PI0FastPolicy(PreTrainedPolicy):
        # Initialize model without loading weights
        # Check if dataset_stats were provided in kwargs
-        if _transformers_available:
-            from transformers.modeling_utils import no_init_weights
-
-            with no_init_weights():
-                model = cls(config, **kwargs)
-                model.model.paligemma_with_expert.paligemma.tie_weights()
-        else:
-            model = cls(config, **kwargs)
+        model = cls(config, **kwargs)
        # Now manually load and remap the state dict
        try:
@@ -106,6 +106,9 @@ class SmolVLAConfig(PreTrainedConfig):
    # Real-Time Chunking (RTC) configuration
    rtc_config: RTCConfig | None = None
+    compile_model: bool = False  # Whether to use torch.compile for model optimization
+    compile_mode: str = "max-autotune"  # Torch compile mode
+
    def __post_init__(self):
        super().__post_init__()
@@ -593,6 +593,12 @@ class VLAFlowMatching(nn.Module):
self.prefix_length = self.config.prefix_length self.prefix_length = self.config.prefix_length
self.rtc_processor = rtc_processor self.rtc_processor = rtc_processor
# Compile model if requested
if config.compile_model:
torch.set_float32_matmul_precision("high")
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
self.forward = torch.compile(self.forward, mode=config.compile_mode)
def _rtc_enabled(self): def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled return self.config.rtc_config is not None and self.config.rtc_config.enabled
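The new flags follow the common pattern of rebinding hot methods to torch.compile wrappers at construction time. A self-contained sketch of that pattern (the tiny module below is illustrative; only the two config field names come from the diff):

import torch

class TinyPolicy(torch.nn.Module):
    def __init__(self, compile_model: bool = False, compile_mode: str = "max-autotune"):
        super().__init__()
        self.net = torch.nn.Linear(8, 8)
        if compile_model:
            # Allow TF32 matmuls: a small float32 precision trade-off for GPU speed.
            torch.set_float32_matmul_precision("high")
            # Rebinding the bound method makes nn.Module.__call__ hit the compiled version.
            self.forward = torch.compile(self.forward, mode=compile_mode)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

policy = TinyPolicy(compile_model=True)
out = policy(torch.randn(2, 8))  # first call triggers compilation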
+1 -1
@@ -155,7 +155,7 @@ class DatasetRecordConfig:
    repo_id: str
    # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
    single_task: str
-    # Root directory where the dataset will be stored (e.g. 'dataset/path').
+    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
    root: str | Path | None = None
    # Limit the frames per second.
    fps: int = 30
+1 -1
@@ -80,7 +80,7 @@ class DatasetReplayConfig:
    repo_id: str
    # Episode to replay.
    episode: int
-    # Root directory where the dataset will be stored (e.g. 'dataset/path').
+    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
    root: str | Path | None = None
    # Limit the frames per second. By default, uses the policy fps.
    fps: int = 30
+2 -2
@@ -380,10 +380,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
        "dataloading_s": AverageMeter("data_s", ":.3f"),
    }
-    # Use effective batch size for proper epoch calculation in distributed training
+    # Keep global batch size for logging; MetricsTracker handles world size internally.
    effective_batch_size = cfg.batch_size * accelerator.num_processes
    train_tracker = MetricsTracker(
-        effective_batch_size,
+        cfg.batch_size,
        dataset.num_frames,
        dataset.num_episodes,
        train_metrics,
+4 -2
@@ -104,9 +104,10 @@ class MetricsTracker:
        self.metrics = metrics
        self.steps = initial_step
+        world_size = accelerator.num_processes if accelerator else 1
        # A sample is an (observation,action) pair, where observation and action
        # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
-        self.samples = self.steps * self._batch_size
+        self.samples = self.steps * self._batch_size * world_size
        self.episodes = self.samples / self._avg_samples_per_ep
        self.epochs = self.samples / self._num_frames
        self.accelerator = accelerator
@@ -132,7 +133,8 @@ class MetricsTracker:
        Updates metrics that depend on 'step' for one step.
        """
        self.steps += 1
-        self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
+        world_size = self.accelerator.num_processes if self.accelerator else 1
+        self.samples += self._batch_size * world_size
        self.episodes = self.samples / self._avg_samples_per_ep
        self.epochs = self.samples / self._num_frames
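The fix makes initialization and step() apply the same world-size scaling to the per-process batch size, instead of train.py pre-multiplying and step() multiplying again. Quick arithmetic with the values from the new tests below:

# Per-process batch size; MetricsTracker scales by world size in one place.
batch_size, world_size = 32, 2

samples_at_init = 10 * batch_size * world_size                  # init at step 10 -> 640
samples_after_step = samples_at_init + batch_size * world_size  # one step() -> 704
assert samples_after_step == 11 * batch_size * world_size       # same as stepping 11 times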
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8
+oid sha256:54aecbc1af72a4cd5e9261492f5e7601890517516257aacdf2a0ffb3ce281f1b
 size 992
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002
+oid sha256:88a9c3775a2aa1e90a08850521970070a4fcf0f6b82aab43cd8ccc5cf77e0013
 size 47424
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5
+oid sha256:91a2635e05a75fe187a5081504c5f35ce3417378813fa2deaf9ca4e8200e1819
 size 68
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf
+oid sha256:645bff922ac7bea63ad018ebf77c303c0e4cd2c1c0dc5ef3192865281bef3dc6
 size 47424
+36
@@ -24,6 +24,11 @@ def mock_metrics():
    return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
+class MockAccelerator:
+    def __init__(self, num_processes: int):
+        self.num_processes = num_processes
+
+
 def test_average_meter_initialization():
    meter = AverageMeter("loss", ":.2f")
    assert meter.name == "loss"
@@ -82,6 +87,37 @@ def test_metrics_tracker_step(mock_metrics):
    assert tracker.epochs == tracker.samples / 1000
+def test_metrics_tracker_initialization_with_accelerator(mock_metrics):
+    tracker = MetricsTracker(
+        batch_size=32,
+        num_frames=1000,
+        num_episodes=50,
+        metrics=mock_metrics,
+        initial_step=10,
+        accelerator=MockAccelerator(num_processes=2),
+    )
+    assert tracker.steps == 10
+    assert tracker.samples == 10 * 32 * 2
+    assert tracker.episodes == tracker.samples / (1000 / 50)
+    assert tracker.epochs == tracker.samples / 1000
+
+
+def test_metrics_tracker_step_with_accelerator(mock_metrics):
+    tracker = MetricsTracker(
+        batch_size=32,
+        num_frames=1000,
+        num_episodes=50,
+        metrics=mock_metrics,
+        initial_step=5,
+        accelerator=MockAccelerator(num_processes=2),
+    )
+    tracker.step()
+    assert tracker.steps == 6
+    assert tracker.samples == (5 * 32 * 2) + (32 * 2)
+    assert tracker.episodes == tracker.samples / (1000 / 50)
+    assert tracker.epochs == tracker.samples / 1000
+
+
 def test_metrics_tracker_getattr(mock_metrics):
    tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
    assert tracker.loss == mock_metrics["loss"]