refactor: several fixes

This commit is contained in:
Steven Palma
2026-04-10 15:35:31 +02:00
parent e2381633cd
commit 882a6b0965
21 changed files with 407 additions and 356 deletions
+10 -9
View File
@@ -70,6 +70,9 @@ dependencies = [
"huggingface-hub>=1.0.0,<2.0.0",
# Environments
# NOTE: gymnasium is used in lerobot.envs (lerobot-train, lerobot-eval), policies/factory,
# and robots/unitree. Moving it to an optional extra would require import guards across many
# tightly-coupled modules. Candidate for a future refactor to decouple envs from the core.
"gymnasium>=1.1.1,<2.0.0",
# Lightweight utilities
@@ -83,11 +86,11 @@ dependencies = [
# ── Feature-scoped extras ──────────────────────────────────
dataset = [
"datasets>=4.0.0,<5.0.0",
"av>=15.0.0,<16.0.0",
"lerobot[av-dep]",
"torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
"jsonlines>=4.0.0,<5.0.0",
]
train = [
training = [
"lerobot[dataset]",
"accelerate>=1.10.0,<2.0.0",
"wandb>=0.24.0,<0.25.0",
@@ -110,13 +113,12 @@ build = [
# lerobot-record, lerobot-replay, lerobot-calibrate, lerobot-teleoperate, etc.
robot = ["lerobot[dataset]", "lerobot[hardware]", "lerobot[viz]"]
# lerobot-eval
evaluation = ["av>=15.0.0,<16.0.0"]
# lerobot-train
training = ["lerobot[train]"]
evaluation = ["lerobot[av-dep]"]
# lerobot-dataset-viz, lerobot-imgtransform-viz
dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
# Common
av-dep = ["av>=15.0.0,<16.0.0"]
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.9.17"]
transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249
@@ -186,8 +188,8 @@ async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
# Development
dev = ["lerobot[dataset]", "lerobot[train]", "lerobot[hardware]", "lerobot[viz]", "pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
test = ["lerobot[dataset]", "lerobot[train]", "lerobot[hardware]", "lerobot[viz]", "pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
dev = ["lerobot[dataset]", "lerobot[training]", "lerobot[hardware]", "lerobot[viz]", "pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
test = ["lerobot[dataset]", "lerobot[training]", "lerobot[hardware]", "lerobot[viz]", "pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
@@ -201,7 +203,7 @@ metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
all = [
# Feature-scoped extras
"lerobot[dataset]",
"lerobot[train]",
"lerobot[training]",
"lerobot[hardware]",
"lerobot[viz]",
"lerobot[build]",
@@ -294,7 +296,6 @@ ignore = [
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401", "F403", "E402"]
"src/lerobot/scripts/*" = ["E402"] # require_package gates before imports
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
[tool.ruff.lint.isort]
-23
View File
@@ -17,9 +17,7 @@ import contextlib
import importlib.resources
import json
import logging
from collections.abc import Iterator
from pathlib import Path
from typing import Any
import datasets
import numpy as np
@@ -281,27 +279,6 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
raise ForwardCompatibilityError(repo_id, min(upper_versions))
def cycle(iterable: Any) -> Iterator[Any]:
"""Create a dataloader-safe cyclical iterator.
This is an equivalent of `itertools.cycle` but is safe for use with
PyTorch DataLoaders with multiple workers.
See https://github.com/pytorch/pytorch/issues/23900 for details.
Args:
iterable: The iterable to cycle over.
Yields:
Items from the iterable, restarting from the beginning when exhausted.
"""
iterator = iter(iterable)
while True:
try:
yield next(iterator)
except StopIteration:
iterator = iter(iterable)
def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) -> None:
"""Create a branch on an existing Hugging Face repo.
@@ -29,19 +29,24 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor, nn
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import (
from lerobot.utils.import_utils import require_package
require_package("diffusers", extra="training")
from diffusers.schedulers.scheduling_ddim import DDIMScheduler # noqa: E402
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler # noqa: E402
from torch import Tensor, nn # noqa: E402
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig # noqa: E402
from lerobot.policies.pretrained import PreTrainedPolicy # noqa: E402
from lerobot.policies.utils import ( # noqa: E402
get_device_from_parameters,
get_dtype_from_parameters,
get_output_shape,
populate_queues,
)
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE # noqa: E402
class DiffusionPolicy(PreTrainedPolicy):
@@ -16,15 +16,20 @@
import torch
import torch.nn.functional as F # noqa: N812
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import (
from lerobot.utils.import_utils import require_package
require_package("diffusers", extra="training")
from diffusers import ConfigMixin, ModelMixin # noqa: E402
from diffusers.configuration_utils import register_to_config # noqa: E402
from diffusers.models.attention import Attention, FeedForward # noqa: E402
from diffusers.models.embeddings import ( # noqa: E402
SinusoidalPositionalEmbedding,
TimestepEmbedding,
Timesteps,
)
from torch import nn
from torch import nn # noqa: E402
class TimestepEncoder(nn.Module):
@@ -34,12 +34,17 @@ import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.utils.import_utils import _transformers_available
from lerobot.utils.import_utils import require_package
require_package("diffusers", extra="training")
from diffusers.schedulers.scheduling_ddim import DDIMScheduler # noqa: E402
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler # noqa: E402
from torch import Tensor # noqa: E402
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig # noqa: E402
from lerobot.utils.import_utils import _transformers_available # noqa: E402
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
@@ -47,9 +52,9 @@ if TYPE_CHECKING or _transformers_available:
else:
CLIPTextModel = None
CLIPVisionModel = None
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import (
from lerobot.policies.pretrained import PreTrainedPolicy # noqa: E402
from lerobot.policies.utils import populate_queues # noqa: E402
from lerobot.utils.constants import ( # noqa: E402
ACTION,
OBS_IMAGES,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -44,10 +44,10 @@ from huggingface_hub import HfApi
from requests import HTTPError
from tqdm import tqdm
from lerobot.datasets import LeRobotDataset
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION
from lerobot.datasets.io_utils import write_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.utils import init_logging
@@ -59,6 +59,7 @@ from datasets import Dataset, Features, Image
from huggingface_hub import HfApi, snapshot_download
from requests import HTTPError
from lerobot.datasets import LeRobotDataset
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION
from lerobot.datasets.io_utils import (
@@ -72,7 +73,6 @@ from lerobot.datasets.io_utils import (
write_stats,
write_tasks,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -82,12 +82,11 @@ from lerobot.datasets.utils import (
LEGACY_EPISODES_PATH,
LEGACY_EPISODES_STATS_PATH,
LEGACY_TASKS_PATH,
flatten_dict,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.utils import init_logging
from lerobot.utils.utils import flatten_dict, init_logging
V21 = "v2.1"
V30 = "v3.0"
+1 -1
View File
@@ -70,7 +70,7 @@ import torch
import torch.utils.data
import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets import LeRobotDataset
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
from lerobot.utils.utils import init_logging
+1 -1
View File
@@ -178,6 +178,7 @@ from pathlib import Path
import draccus
from lerobot.configs import parser
from lerobot.datasets import LeRobotDataset
from lerobot.datasets.dataset_tools import (
convert_image_to_video_dataset,
delete_episodes,
@@ -187,7 +188,6 @@ from lerobot.datasets.dataset_tools import (
remove_feature,
split_dataset,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.utils import init_logging
@@ -36,7 +36,7 @@ import draccus
from torchvision.transforms import ToPILImage
from lerobot.configs.default import DatasetConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets import LeRobotDataset
from lerobot.transforms import (
ImageTransforms,
ImageTransformsConfig,
+1 -1
View File
@@ -85,9 +85,9 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets import LeRobotDataset
from lerobot.datasets.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
+1 -1
View File
@@ -46,7 +46,7 @@ from pathlib import Path
from pprint import pformat
from lerobot.configs import parser
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets import LeRobotDataset
from lerobot.processor import (
make_default_robot_action_processor,
)
+22 -22
View File
@@ -24,35 +24,35 @@ from lerobot.utils.import_utils import require_package
require_package("accelerate", extra="training")
import torch
from accelerate import Accelerator
from termcolor import colored
from torch.optim import Optimizer
from tqdm import tqdm
import torch # noqa: E402
from accelerate import Accelerator # noqa: E402
from termcolor import colored # noqa: E402
from torch.optim import Optimizer # noqa: E402
from tqdm import tqdm # noqa: E402
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.utils import close_envs
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.rl.wandb_utils import WandBLogger
from lerobot.scripts.lerobot_eval import eval_policy_all
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import (
from lerobot.configs import parser # noqa: E402
from lerobot.configs.train import TrainPipelineConfig # noqa: E402
from lerobot.datasets import EpisodeAwareSampler # noqa: E402
from lerobot.datasets.factory import make_dataset # noqa: E402
from lerobot.envs.factory import make_env, make_env_pre_post_processors # noqa: E402
from lerobot.envs.utils import close_envs # noqa: E402
from lerobot.optim.factory import make_optimizer_and_scheduler # noqa: E402
from lerobot.policies.factory import make_policy, make_pre_post_processors # noqa: E402
from lerobot.policies.pretrained import PreTrainedPolicy # noqa: E402
from lerobot.rl.wandb_utils import WandBLogger # noqa: E402
from lerobot.scripts.lerobot_eval import eval_policy_all # noqa: E402
from lerobot.utils.import_utils import register_third_party_plugins # noqa: E402
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker # noqa: E402
from lerobot.utils.random_utils import set_seed # noqa: E402
from lerobot.utils.train_utils import ( # noqa: E402
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.utils.utils import (
from lerobot.utils.utils import ( # noqa: E402
cycle,
format_big_number,
has_method,
init_logging,
@@ -62,7 +62,7 @@ else:
from lerobot.configs import parser
from lerobot.configs.types import NormalizationMode
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets import LeRobotDataset
from lerobot.utils.constants import ACTION, OBS_STATE
+16 -245
View File
@@ -1,6 +1,4 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,248 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any
import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import (
Transform,
functional as F, # noqa: N812
from lerobot.transforms.transforms import (
ImageTransformConfig,
ImageTransforms,
ImageTransformsConfig,
RandomSubsetApply,
SharpnessJitter,
make_transform_from_config,
)
class RandomSubsetApply(Transform):
"""Apply a random subset of N transformations from a list of transformations.
Args:
transforms: list of transformations.
p: represents the multinomial probabilities (with no replacement) used for sampling the transform.
If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms
have the same probability.
n_subset: number of transformations to apply. If ``None``, all transforms are applied.
Must be in [1, len(transforms)].
random_order: apply transformations in a random order.
"""
def __init__(
self,
transforms: Sequence[Callable],
p: list[float] | None = None,
n_subset: int | None = None,
random_order: bool = False,
) -> None:
super().__init__()
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
if p is None:
p = [1] * len(transforms)
elif len(p) != len(transforms):
raise ValueError(
f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}"
)
if n_subset is None:
n_subset = len(transforms)
elif not isinstance(n_subset, int):
raise TypeError("n_subset should be an int or None")
elif not (1 <= n_subset <= len(transforms)):
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
self.transforms = transforms
total = sum(p)
self.p = [prob / total for prob in p]
self.n_subset = n_subset
self.random_order = random_order
self.selected_transforms = None
def forward(self, *inputs: Any) -> Any:
needs_unpacking = len(inputs) > 1
selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
if not self.random_order:
selected_indices = selected_indices.sort().values
self.selected_transforms = [self.transforms[i] for i in selected_indices]
for transform in self.selected_transforms:
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs
def extra_repr(self) -> str:
return (
f"transforms={self.transforms}, "
f"p={self.p}, "
f"n_subset={self.n_subset}, "
f"random_order={self.random_order}"
)
class SharpnessJitter(Transform):
"""Randomly change the sharpness of an image or video.
Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly.
While v2.RandomAdjustSharpness applies with a given probability a fixed sharpness_factor to an image,
SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of
augmentations as a result.
A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness
by a factor of 2.
If the input is a :class:`torch.Tensor`,
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from
[max(0, 1 - sharpness), 1 + sharpness] or the given
[min, max]. Should be non negative numbers.
"""
def __init__(self, sharpness: float | Sequence[float]) -> None:
super().__init__()
self.sharpness = self._check_input(sharpness)
def _check_input(self, sharpness):
if isinstance(sharpness, (int | float)):
if sharpness < 0:
raise ValueError("If sharpness is a single number, it must be non negative.")
sharpness = [1.0 - sharpness, 1.0 + sharpness]
sharpness[0] = max(sharpness[0], 0.0)
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
sharpness = [float(v) for v in sharpness]
else:
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
if not 0.0 <= sharpness[0] <= sharpness[1]:
raise ValueError(f"sharpness values should be between (0., inf), but got {sharpness}.")
return float(sharpness[0]), float(sharpness[1])
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
return {"sharpness_factor": sharpness_factor}
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
sharpness_factor = params["sharpness_factor"]
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
@dataclass
class ImageTransformConfig:
"""
For each transform, the following parameters are available:
weight: This represents the multinomial probability (with no replacement)
used for sampling the transform. If the sum of the weights is not 1,
they will be normalized.
type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a
custom transform defined here.
kwargs: Lower & upper bound respectively used for sampling the transform's parameter
(following uniform distribution) when it's applied.
"""
weight: float = 1.0
type: str = "Identity"
kwargs: dict[str, Any] = field(default_factory=dict)
@dataclass
class ImageTransformsConfig:
"""
These transforms are all using standard torchvision.transforms.v2
You can find out how these transformations affect images here:
https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
We use a custom RandomSubsetApply container to sample them.
"""
# Set this flag to `true` to enable transforms during training
enable: bool = False
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
# It's an integer in the interval [1, number_of_available_transforms].
max_num_transforms: int = 3
# By default, transforms are applied in Torchvision's suggested order (shown below).
# Set this to True to apply them in a random order.
random_order: bool = False
tfs: dict[str, ImageTransformConfig] = field(
default_factory=lambda: {
"brightness": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"brightness": (0.8, 1.2)},
),
"contrast": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"contrast": (0.8, 1.2)},
),
"saturation": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"saturation": (0.5, 1.5)},
),
"hue": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"hue": (-0.05, 0.05)},
),
"sharpness": ImageTransformConfig(
weight=1.0,
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)},
),
"affine": ImageTransformConfig(
weight=1.0,
type="RandomAffine",
kwargs={"degrees": (-5.0, 5.0), "translate": (0.05, 0.05)},
),
}
)
def make_transform_from_config(cfg: ImageTransformConfig):
if cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
transform_cls = getattr(v2, cfg.type, None)
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
return transform_cls(**cfg.kwargs)
raise ValueError(
f"Transform '{cfg.type}' is not valid. It must be a class in "
f"torchvision.transforms.v2 or 'SharpnessJitter'."
)
class ImageTransforms(Transform):
"""A class to compose image transforms based on configuration."""
def __init__(self, cfg: ImageTransformsConfig) -> None:
super().__init__()
self._cfg = cfg
self.weights = []
self.transforms = {}
for tf_name, tf_cfg in cfg.tfs.items():
if tf_cfg.weight <= 0.0:
continue
self.transforms[tf_name] = make_transform_from_config(tf_cfg)
self.weights.append(tf_cfg.weight)
n_subset = min(len(self.transforms), cfg.max_num_transforms)
if n_subset == 0 or not cfg.enable:
self.tf = v2.Identity()
else:
self.tf = RandomSubsetApply(
transforms=list(self.transforms.values()),
p=self.weights,
n_subset=n_subset,
random_order=cfg.random_order,
)
def forward(self, *inputs: Any) -> Any:
return self.tf(*inputs)
__all__ = [
"ImageTransformConfig",
"ImageTransforms",
"ImageTransformsConfig",
"RandomSubsetApply",
"SharpnessJitter",
"make_transform_from_config",
]
+260
View File
@@ -0,0 +1,260 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any
import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import (
Transform,
functional as F, # noqa: N812
)
class RandomSubsetApply(Transform):
"""Apply a random subset of N transformations from a list of transformations.
Args:
transforms: list of transformations.
p: represents the multinomial probabilities (with no replacement) used for sampling the transform.
If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms
have the same probability.
n_subset: number of transformations to apply. If ``None``, all transforms are applied.
Must be in [1, len(transforms)].
random_order: apply transformations in a random order.
"""
def __init__(
self,
transforms: Sequence[Callable],
p: list[float] | None = None,
n_subset: int | None = None,
random_order: bool = False,
) -> None:
super().__init__()
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
if p is None:
p = [1] * len(transforms)
elif len(p) != len(transforms):
raise ValueError(
f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}"
)
if n_subset is None:
n_subset = len(transforms)
elif not isinstance(n_subset, int):
raise TypeError("n_subset should be an int or None")
elif not (1 <= n_subset <= len(transforms)):
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
self.transforms = transforms
total = sum(p)
self.p = [prob / total for prob in p]
self.n_subset = n_subset
self.random_order = random_order
self.selected_transforms = None
def forward(self, *inputs: Any) -> Any:
needs_unpacking = len(inputs) > 1
selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
if not self.random_order:
selected_indices = selected_indices.sort().values
self.selected_transforms = [self.transforms[i] for i in selected_indices]
for transform in self.selected_transforms:
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs
def extra_repr(self) -> str:
return (
f"transforms={self.transforms}, "
f"p={self.p}, "
f"n_subset={self.n_subset}, "
f"random_order={self.random_order}"
)
class SharpnessJitter(Transform):
"""Randomly change the sharpness of an image or video.
Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly.
While v2.RandomAdjustSharpness applies with a given probability a fixed sharpness_factor to an image,
SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of
augmentations as a result.
A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness
by a factor of 2.
If the input is a :class:`torch.Tensor`,
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from
[max(0, 1 - sharpness), 1 + sharpness] or the given
[min, max]. Should be non negative numbers.
"""
def __init__(self, sharpness: float | Sequence[float]) -> None:
super().__init__()
self.sharpness = self._check_input(sharpness)
def _check_input(self, sharpness):
if isinstance(sharpness, (int | float)):
if sharpness < 0:
raise ValueError("If sharpness is a single number, it must be non negative.")
sharpness = [1.0 - sharpness, 1.0 + sharpness]
sharpness[0] = max(sharpness[0], 0.0)
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
sharpness = [float(v) for v in sharpness]
else:
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
if not 0.0 <= sharpness[0] <= sharpness[1]:
raise ValueError(f"sharpness values should be between (0., inf), but got {sharpness}.")
return float(sharpness[0]), float(sharpness[1])
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
return {"sharpness_factor": sharpness_factor}
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
sharpness_factor = params["sharpness_factor"]
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
@dataclass
class ImageTransformConfig:
"""
For each transform, the following parameters are available:
weight: This represents the multinomial probability (with no replacement)
used for sampling the transform. If the sum of the weights is not 1,
they will be normalized.
type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a
custom transform defined here.
kwargs: Lower & upper bound respectively used for sampling the transform's parameter
(following uniform distribution) when it's applied.
"""
weight: float = 1.0
type: str = "Identity"
kwargs: dict[str, Any] = field(default_factory=dict)
@dataclass
class ImageTransformsConfig:
"""
These transforms are all using standard torchvision.transforms.v2
You can find out how these transformations affect images here:
https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
We use a custom RandomSubsetApply container to sample them.
"""
# Set this flag to `true` to enable transforms during training
enable: bool = False
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
# It's an integer in the interval [1, number_of_available_transforms].
max_num_transforms: int = 3
# By default, transforms are applied in Torchvision's suggested order (shown below).
# Set this to True to apply them in a random order.
random_order: bool = False
tfs: dict[str, ImageTransformConfig] = field(
default_factory=lambda: {
"brightness": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"brightness": (0.8, 1.2)},
),
"contrast": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"contrast": (0.8, 1.2)},
),
"saturation": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"saturation": (0.5, 1.5)},
),
"hue": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"hue": (-0.05, 0.05)},
),
"sharpness": ImageTransformConfig(
weight=1.0,
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)},
),
"affine": ImageTransformConfig(
weight=1.0,
type="RandomAffine",
kwargs={"degrees": (-5.0, 5.0), "translate": (0.05, 0.05)},
),
}
)
def make_transform_from_config(cfg: ImageTransformConfig):
if cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
transform_cls = getattr(v2, cfg.type, None)
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
return transform_cls(**cfg.kwargs)
raise ValueError(
f"Transform '{cfg.type}' is not valid. It must be a class in "
f"torchvision.transforms.v2 or 'SharpnessJitter'."
)
class ImageTransforms(Transform):
"""A class to compose image transforms based on configuration."""
def __init__(self, cfg: ImageTransformsConfig) -> None:
super().__init__()
self._cfg = cfg
self.weights = []
self.transforms = {}
for tf_name, tf_cfg in cfg.tfs.items():
if tf_cfg.weight <= 0.0:
continue
self.transforms[tf_name] = make_transform_from_config(tf_cfg)
self.weights.append(tf_cfg.weight)
n_subset = min(len(self.transforms), cfg.max_num_transforms)
if n_subset == 0 or not cfg.enable:
self.tf = v2.Identity()
else:
self.tf = RandomSubsetApply(
transforms=list(self.transforms.values()),
p=self.weights,
n_subset=n_subset,
random_order=cfg.random_order,
)
def forward(self, *inputs: Any) -> Any:
return self.tf(*inputs)
+2 -1
View File
@@ -70,10 +70,11 @@ def is_package_available(
def get_safe_default_codec():
logger = logging.getLogger(__name__)
if importlib.util.find_spec("torchcodec"):
return "torchcodec"
else:
logging.warning(
logger.warning(
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
)
return "pyav"
+15 -4
View File
@@ -14,9 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
@@ -58,10 +61,18 @@ def write_video(video_path: str | Path, stacked_frames: list, fps: int) -> None:
import av
with av.open(str(video_path), mode="w") as container:
height, width = stacked_frames[0].shape[:2]
# Ensure dimensions are even for yuv420p compatibility
height = height if height % 2 == 0 else height - 1
width = width if width % 2 == 0 else width - 1
orig_height, orig_width = stacked_frames[0].shape[:2]
# yuv420p requires even dimensions; crop by one pixel if needed
height = orig_height if orig_height % 2 == 0 else orig_height - 1
width = orig_width if orig_width % 2 == 0 else orig_width - 1
if height != orig_height or width != orig_width:
logger.warning(
"Frame dimensions %dx%d are not even; cropping to %dx%d for yuv420p compatibility.",
orig_width,
orig_height,
width,
height,
)
stream = container.add_stream("libx264", rate=fps)
stream.width = width
stream.height = height
+23 -1
View File
@@ -22,11 +22,12 @@ import select
import subprocess
import sys
import time
from collections.abc import Iterator
from copy import copy, deepcopy
from datetime import datetime
from pathlib import Path
from statistics import mean
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import numpy as np
@@ -252,6 +253,27 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
return outdict
def cycle(iterable: Any) -> Iterator[Any]:
"""Create a dataloader-safe cyclical iterator.
This is an equivalent of `itertools.cycle` but is safe for use with
PyTorch DataLoaders with multiple workers.
See https://github.com/pytorch/pytorch/issues/23900 for details.
Args:
iterable: The iterable to cycle over.
Yields:
Items from the iterable, restarting from the beginning when exhausted.
"""
iterator = iter(iterable)
while True:
try:
yield next(iterator)
except StopIteration:
iterator = iter(iterable)
class SuppressProgressBars:
"""
Context manager to suppress progress bars.
+1 -1
View File
@@ -29,7 +29,6 @@ from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env, make_env_config
from lerobot.envs.utils import close_envs, preprocess_observation
from lerobot.optim.factory import make_optimizer_and_scheduler
@@ -46,6 +45,7 @@ from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.random_utils import seeded_context
from lerobot.utils.utils import cycle
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
Generated
+14 -20
View File
@@ -2275,6 +2275,9 @@ async = [
{ name = "matplotlib" },
{ name = "protobuf" },
]
av-dep = [
{ name = "av" },
]
build = [
{ name = "cmake" },
{ name = "setuptools" },
@@ -2483,15 +2486,6 @@ test = [
{ name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
{ name = "wandb" },
]
train = [
{ name = "accelerate" },
{ name = "av" },
{ name = "datasets" },
{ name = "diffusers" },
{ name = "jsonlines" },
{ name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
{ name = "wandb" },
]
training = [
{ name = "accelerate" },
{ name = "av" },
@@ -2534,16 +2528,15 @@ xvla = [
[package.metadata]
requires-dist = [
{ name = "accelerate", marker = "extra == 'smolvla'", specifier = ">=1.7.0,<2.0.0" },
{ name = "accelerate", marker = "extra == 'train'", specifier = ">=1.10.0,<2.0.0" },
{ name = "av", marker = "extra == 'dataset'", specifier = ">=15.0.0,<16.0.0" },
{ name = "av", marker = "extra == 'evaluation'", specifier = ">=15.0.0,<16.0.0" },
{ name = "accelerate", marker = "extra == 'training'", specifier = ">=1.10.0,<2.0.0" },
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
{ name = "cmake", marker = "extra == 'build'", specifier = ">=3.29.0.1,<4.2.0" },
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
{ name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" },
{ name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" },
{ name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" },
{ name = "deepdiff", marker = "extra == 'hardware'", specifier = ">=7.0.1,<9.0.0" },
{ name = "diffusers", marker = "extra == 'train'", specifier = ">=0.27.2,<0.36.0" },
{ name = "diffusers", marker = "extra == 'training'", specifier = ">=0.27.2,<0.36.0" },
{ name = "dm-tree", marker = "extra == 'groot'", specifier = ">=0.1.8,<1.0.0" },
{ name = "draccus", specifier = "==0.10.0" },
{ name = "dynamixel-sdk", marker = "extra == 'dynamixel'", specifier = ">=3.7.31,<3.9.0" },
@@ -2565,6 +2558,8 @@ requires-dist = [
{ name = "jsonlines", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" },
{ name = "lerobot", extras = ["aloha"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["async"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["av-dep"], marker = "extra == 'dataset'" },
{ name = "lerobot", extras = ["av-dep"], marker = "extra == 'evaluation'" },
{ name = "lerobot", extras = ["build"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["can-dep"], marker = "extra == 'damiao'" },
{ name = "lerobot", extras = ["can-dep"], marker = "extra == 'robstride'" },
@@ -2578,7 +2573,7 @@ requires-dist = [
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'pusht'" },
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'robot'" },
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'test'" },
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'train'" },
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'training'" },
{ name = "lerobot", extras = ["dev"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" },
@@ -2625,10 +2620,9 @@ requires-dist = [
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["smolvla"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["test"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["train"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["train"], marker = "extra == 'dev'" },
{ name = "lerobot", extras = ["train"], marker = "extra == 'test'" },
{ name = "lerobot", extras = ["train"], marker = "extra == 'training'" },
{ name = "lerobot", extras = ["training"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["training"], marker = "extra == 'dev'" },
{ name = "lerobot", extras = ["training"], marker = "extra == 'test'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" },
@@ -2694,9 +2688,9 @@ requires-dist = [
{ name = "torchdiffeq", marker = "extra == 'wallx'", specifier = ">=0.2.4,<0.3.0" },
{ name = "torchvision", specifier = ">=0.22.0,<0.26.0" },
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = "==5.3.0" },
{ name = "wandb", marker = "extra == 'train'", specifier = ">=0.24.0,<0.25.0" },
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
]
provides-extras = ["dataset", "train", "hardware", "viz", "build", "robot", "evaluation", "training", "dataset-viz", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "qwen-vl-utils-dep", "matplotlib-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "kinematics", "intelrealsense", "phone", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "hilserl", "async", "peft", "dev", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
provides-extras = ["dataset", "training", "hardware", "viz", "build", "robot", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "qwen-vl-utils-dep", "matplotlib-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "kinematics", "intelrealsense", "phone", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "hilserl", "async", "peft", "dev", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
[[package]]
name = "librt"