mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
refactor: more changes
This commit is contained in:
@@ -23,7 +23,6 @@ from typing import Any
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs.types import PolicyFeature
|
from lerobot.configs.types import PolicyFeature
|
||||||
from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
|
|
||||||
|
|
||||||
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
||||||
from lerobot.policies import ( # noqa: F401
|
from lerobot.policies import ( # noqa: F401
|
||||||
@@ -36,6 +35,7 @@ from lerobot.policies import ( # noqa: F401
|
|||||||
)
|
)
|
||||||
from lerobot.robots.robot import Robot
|
from lerobot.robots.robot import Robot
|
||||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
|
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
|
||||||
|
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||||
from lerobot.utils.utils import init_logging
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
Action = torch.Tensor
|
Action = torch.Tensor
|
||||||
|
|||||||
@@ -746,7 +746,7 @@ def save_annotations_to_dataset(
|
|||||||
dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse"
|
dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse"
|
||||||
):
|
):
|
||||||
"""Save annotations to LeRobot dataset parquet format."""
|
"""Save annotations to LeRobot dataset parquet format."""
|
||||||
from lerobot.datasets.io_utils import load_episodes
|
from lerobot.datasets import load_episodes
|
||||||
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH
|
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH
|
||||||
|
|
||||||
episodes_dataset = load_episodes(dataset_path)
|
episodes_dataset = load_episodes(dataset_path)
|
||||||
@@ -841,7 +841,7 @@ def generate_auto_sparse_annotations(
|
|||||||
|
|
||||||
def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]:
|
def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]:
|
||||||
"""Load annotations from LeRobot dataset parquet files."""
|
"""Load annotations from LeRobot dataset parquet files."""
|
||||||
from lerobot.datasets.io_utils import load_episodes
|
from lerobot.datasets import load_episodes
|
||||||
|
|
||||||
episodes_dataset = load_episodes(dataset_path)
|
episodes_dataset = load_episodes(dataset_path)
|
||||||
if not episodes_dataset or len(episodes_dataset) == 0:
|
if not episodes_dataset or len(episodes_dataset) == 0:
|
||||||
|
|||||||
@@ -20,10 +20,15 @@ from lerobot.utils.import_utils import require_package
|
|||||||
require_package("datasets", extra="dataset")
|
require_package("datasets", extra="dataset")
|
||||||
|
|
||||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||||
|
from lerobot.datasets.factory import make_dataset
|
||||||
|
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||||
|
from lerobot.datasets.io_utils import load_episodes, write_stats
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||||
|
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||||
|
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EpisodeAwareSampler",
|
"EpisodeAwareSampler",
|
||||||
@@ -31,4 +36,11 @@ __all__ = [
|
|||||||
"LeRobotDatasetMetadata",
|
"LeRobotDatasetMetadata",
|
||||||
"MultiLeRobotDataset",
|
"MultiLeRobotDataset",
|
||||||
"StreamingLeRobotDataset",
|
"StreamingLeRobotDataset",
|
||||||
|
"VideoEncodingManager",
|
||||||
|
"aggregate_pipeline_dataset_features",
|
||||||
|
"create_initial_features",
|
||||||
|
"load_episodes",
|
||||||
|
"make_dataset",
|
||||||
|
"safe_stop_image_writer",
|
||||||
|
"write_stats",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import pyarrow.parquet as pq
|
|||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from lerobot.datasets.compute_stats import aggregate_stats
|
from lerobot.datasets.compute_stats import aggregate_stats
|
||||||
from lerobot.datasets.feature_utils import _validate_feature_names, create_empty_dataset_info
|
from lerobot.datasets.feature_utils import create_empty_dataset_info
|
||||||
from lerobot.datasets.io_utils import (
|
from lerobot.datasets.io_utils import (
|
||||||
get_file_size_in_mb,
|
get_file_size_in_mb,
|
||||||
load_episodes,
|
load_episodes,
|
||||||
@@ -39,7 +39,6 @@ from lerobot.datasets.io_utils import (
|
|||||||
)
|
)
|
||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
DEFAULT_EPISODES_PATH,
|
DEFAULT_EPISODES_PATH,
|
||||||
DEFAULT_FEATURES,
|
|
||||||
INFO_PATH,
|
INFO_PATH,
|
||||||
check_version_compatibility,
|
check_version_compatibility,
|
||||||
flatten_dict,
|
flatten_dict,
|
||||||
@@ -49,7 +48,8 @@ from lerobot.datasets.utils import (
|
|||||||
update_chunk_file_indices,
|
update_chunk_file_indices,
|
||||||
)
|
)
|
||||||
from lerobot.datasets.video_utils import get_video_info
|
from lerobot.datasets.video_utils import get_video_info
|
||||||
from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
|
from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
|
||||||
|
from lerobot.utils.feature_utils import _validate_feature_names
|
||||||
|
|
||||||
CODEBASE_VERSION = "v3.0"
|
CODEBASE_VERSION = "v3.0"
|
||||||
|
|
||||||
|
|||||||
@@ -25,12 +25,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|||||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||||
from lerobot.transforms import ImageTransforms
|
from lerobot.transforms import ImageTransforms
|
||||||
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
|
from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD
|
||||||
|
|
||||||
IMAGENET_STATS = {
|
|
||||||
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
|
||||||
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_delta_timestamps(
|
def resolve_delta_timestamps(
|
||||||
|
|||||||
@@ -14,22 +14,19 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
|
||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_DATA_PATH,
|
DEFAULT_DATA_PATH,
|
||||||
DEFAULT_FEATURES,
|
|
||||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_VIDEO_PATH,
|
DEFAULT_VIDEO_PATH,
|
||||||
)
|
)
|
||||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
|
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||||
|
|
||||||
|
|
||||||
@@ -71,199 +68,6 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
|||||||
return datasets.Features(hf_features)
|
return datasets.Features(hf_features)
|
||||||
|
|
||||||
|
|
||||||
def _validate_feature_names(features: dict[str, dict]) -> None:
|
|
||||||
"""Validate that feature names do not contain invalid characters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
features (dict): The LeRobot features dictionary.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If any feature name contains '/'.
|
|
||||||
"""
|
|
||||||
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
|
|
||||||
if invalid_features:
|
|
||||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
|
|
||||||
|
|
||||||
|
|
||||||
def hw_to_dataset_features(
|
|
||||||
hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
|
|
||||||
) -> dict[str, dict]:
|
|
||||||
"""Convert hardware-specific features to a LeRobot dataset feature dictionary.
|
|
||||||
|
|
||||||
This function takes a dictionary describing hardware outputs (like joint states
|
|
||||||
or camera image shapes) and formats it into the standard LeRobot feature
|
|
||||||
specification.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hw_features (dict): Dictionary mapping feature names to their type (float for
|
|
||||||
joints) or shape (tuple for images).
|
|
||||||
prefix (str): The prefix to add to the feature keys (e.g., "observation"
|
|
||||||
or "action").
|
|
||||||
use_video (bool): If True, image features are marked as "video", otherwise "image".
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: A LeRobot features dictionary.
|
|
||||||
"""
|
|
||||||
features = {}
|
|
||||||
joint_fts = {
|
|
||||||
key: ftype
|
|
||||||
for key, ftype in hw_features.items()
|
|
||||||
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
|
|
||||||
}
|
|
||||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
|
||||||
|
|
||||||
if joint_fts and prefix == ACTION:
|
|
||||||
features[prefix] = {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": (len(joint_fts),),
|
|
||||||
"names": list(joint_fts),
|
|
||||||
}
|
|
||||||
|
|
||||||
if joint_fts and prefix == OBS_STR:
|
|
||||||
features[f"{prefix}.state"] = {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": (len(joint_fts),),
|
|
||||||
"names": list(joint_fts),
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, shape in cam_fts.items():
|
|
||||||
features[f"{prefix}.images.{key}"] = {
|
|
||||||
"dtype": "video" if use_video else "image",
|
|
||||||
"shape": shape,
|
|
||||||
"names": ["height", "width", "channels"],
|
|
||||||
}
|
|
||||||
|
|
||||||
_validate_feature_names(features)
|
|
||||||
return features
|
|
||||||
|
|
||||||
|
|
||||||
def build_dataset_frame(
|
|
||||||
ds_features: dict[str, dict], values: dict[str, Any], prefix: str
|
|
||||||
) -> dict[str, np.ndarray]:
|
|
||||||
"""Construct a single data frame from raw values based on dataset features.
|
|
||||||
|
|
||||||
A "frame" is a dictionary containing all the data for a single timestep,
|
|
||||||
formatted as numpy arrays according to the feature specification.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ds_features (dict): The LeRobot dataset features dictionary.
|
|
||||||
values (dict): A dictionary of raw values from the hardware/environment.
|
|
||||||
prefix (str): The prefix to filter features by (e.g., "observation"
|
|
||||||
or "action").
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: A dictionary representing a single frame of data.
|
|
||||||
"""
|
|
||||||
frame = {}
|
|
||||||
for key, ft in ds_features.items():
|
|
||||||
if key in DEFAULT_FEATURES or not key.startswith(prefix):
|
|
||||||
continue
|
|
||||||
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
|
||||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
|
||||||
elif ft["dtype"] in ["image", "video"]:
|
|
||||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
|
||||||
|
|
||||||
return frame
|
|
||||||
|
|
||||||
|
|
||||||
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
|
|
||||||
"""Convert dataset features to policy features.
|
|
||||||
|
|
||||||
This function transforms the dataset's feature specification into a format
|
|
||||||
that a policy can use, classifying features by type (e.g., visual, state,
|
|
||||||
action) and ensuring correct shapes (e.g., channel-first for images).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
features (dict): The LeRobot dataset features dictionary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: A dictionary mapping feature keys to `PolicyFeature` objects.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If an image feature does not have a 3D shape.
|
|
||||||
"""
|
|
||||||
# TODO(aliberts): Implement "type" in dataset features and simplify this
|
|
||||||
policy_features = {}
|
|
||||||
for key, ft in features.items():
|
|
||||||
shape = ft["shape"]
|
|
||||||
if ft["dtype"] in ["image", "video"]:
|
|
||||||
type = FeatureType.VISUAL
|
|
||||||
if len(shape) != 3:
|
|
||||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
|
||||||
|
|
||||||
names = ft["names"]
|
|
||||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
|
||||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
|
||||||
shape = (shape[2], shape[0], shape[1])
|
|
||||||
elif key == OBS_ENV_STATE:
|
|
||||||
type = FeatureType.ENV
|
|
||||||
elif key.startswith(OBS_STR):
|
|
||||||
type = FeatureType.STATE
|
|
||||||
elif key.startswith(ACTION):
|
|
||||||
type = FeatureType.ACTION
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
policy_features[key] = PolicyFeature(
|
|
||||||
type=type,
|
|
||||||
shape=shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
return policy_features
|
|
||||||
|
|
||||||
|
|
||||||
def combine_feature_dicts(*dicts: dict) -> dict:
|
|
||||||
"""Merge LeRobot grouped feature dicts.
|
|
||||||
|
|
||||||
- For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
|
|
||||||
- For others (e.g. `observation.images.*`), the last one wins (if they are identical).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
*dicts: A variable number of LeRobot feature dictionaries to merge.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: A single merged feature dictionary.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If there's a dtype mismatch for a feature being merged.
|
|
||||||
"""
|
|
||||||
out: dict = {}
|
|
||||||
for d in dicts:
|
|
||||||
for key, value in d.items():
|
|
||||||
if not isinstance(value, dict):
|
|
||||||
out[key] = value
|
|
||||||
continue
|
|
||||||
|
|
||||||
dtype = value.get("dtype")
|
|
||||||
shape = value.get("shape")
|
|
||||||
is_vector = (
|
|
||||||
dtype not in ("image", "video", "string")
|
|
||||||
and isinstance(shape, tuple)
|
|
||||||
and len(shape) == 1
|
|
||||||
and "names" in value
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_vector:
|
|
||||||
# Initialize or retrieve the accumulating dict for this feature key
|
|
||||||
target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
|
|
||||||
# Ensure consistent data types across merged entries
|
|
||||||
if "dtype" in target and dtype != target["dtype"]:
|
|
||||||
raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
|
|
||||||
|
|
||||||
# Merge feature names: append only new ones to preserve order without duplicates
|
|
||||||
seen = set(target["names"])
|
|
||||||
for n in value["names"]:
|
|
||||||
if n not in seen:
|
|
||||||
target["names"].append(n)
|
|
||||||
seen.add(n)
|
|
||||||
# Recompute the shape to reflect the updated number of features
|
|
||||||
target["shape"] = (len(target["names"]),)
|
|
||||||
else:
|
|
||||||
# For images/videos and non-1D entries: override with the latest definition
|
|
||||||
out[key] = value
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def create_empty_dataset_info(
|
def create_empty_dataset_info(
|
||||||
codebase_version: str,
|
codebase_version: str,
|
||||||
fps: int,
|
fps: int,
|
||||||
|
|||||||
@@ -17,10 +17,10 @@ from collections.abc import Sequence
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from lerobot.configs.types import PipelineFeatureType
|
from lerobot.configs.types import PipelineFeatureType
|
||||||
from lerobot.datasets.feature_utils import hw_to_dataset_features
|
|
||||||
from lerobot.processor import DataProcessorPipeline
|
from lerobot.processor import DataProcessorPipeline
|
||||||
from lerobot.types import RobotAction, RobotObservation
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||||
|
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||||
|
|
||||||
|
|
||||||
def create_initial_features(
|
def create_initial_features(
|
||||||
|
|||||||
@@ -93,14 +93,6 @@ LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
|||||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||||
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
|
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
|
||||||
|
|
||||||
DEFAULT_FEATURES = {
|
|
||||||
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
|
|
||||||
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
|
|
||||||
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
|
|
||||||
"index": {"dtype": "int64", "shape": (1,), "names": None},
|
|
||||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def has_legacy_hub_download_metadata(root: Path) -> bool:
|
def has_legacy_hub_download_metadata(root: Path) -> bool:
|
||||||
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
|
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
|
||||||
|
|||||||
@@ -29,24 +29,17 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
import torchvision
|
import torchvision
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from lerobot.utils.import_utils import require_package
|
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
require_package("diffusers", extra="training")
|
from lerobot.policies.utils import (
|
||||||
|
|
||||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler # noqa: E402
|
|
||||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler # noqa: E402
|
|
||||||
from torch import Tensor, nn # noqa: E402
|
|
||||||
|
|
||||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig # noqa: E402
|
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy # noqa: E402
|
|
||||||
from lerobot.policies.utils import ( # noqa: E402
|
|
||||||
get_device_from_parameters,
|
get_device_from_parameters,
|
||||||
get_dtype_from_parameters,
|
get_dtype_from_parameters,
|
||||||
get_output_shape,
|
get_output_shape,
|
||||||
populate_queues,
|
populate_queues,
|
||||||
)
|
)
|
||||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE # noqa: E402
|
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
|
|
||||||
class DiffusionPolicy(PreTrainedPolicy):
|
class DiffusionPolicy(PreTrainedPolicy):
|
||||||
@@ -156,11 +149,17 @@ class DiffusionPolicy(PreTrainedPolicy):
|
|||||||
return loss, None
|
return loss, None
|
||||||
|
|
||||||
|
|
||||||
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
def _make_noise_scheduler(name: str, **kwargs: dict):
|
||||||
"""
|
"""
|
||||||
Factory for noise scheduler instances of the requested type. All kwargs are passed
|
Factory for noise scheduler instances of the requested type. All kwargs are passed
|
||||||
to the scheduler.
|
to the scheduler.
|
||||||
"""
|
"""
|
||||||
|
from lerobot.utils.import_utils import require_package
|
||||||
|
|
||||||
|
require_package("diffusers", extra="training")
|
||||||
|
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
|
|
||||||
if name == "DDPM":
|
if name == "DDPM":
|
||||||
return DDPMScheduler(**kwargs)
|
return DDPMScheduler(**kwargs)
|
||||||
elif name == "DDIM":
|
elif name == "DDIM":
|
||||||
|
|||||||
@@ -495,7 +495,7 @@ def make_policy(
|
|||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if ds_meta is not None:
|
if ds_meta is not None:
|
||||||
from lerobot.datasets.feature_utils import dataset_to_policy_features
|
from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||||
|
|
||||||
features = dataset_to_policy_features(ds_meta.features)
|
features = dataset_to_policy_features(ds_meta.features)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -34,17 +34,10 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
import torchvision
|
import torchvision
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
from lerobot.utils.import_utils import require_package
|
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||||
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
require_package("diffusers", extra="training")
|
|
||||||
|
|
||||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler # noqa: E402
|
|
||||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler # noqa: E402
|
|
||||||
from torch import Tensor # noqa: E402
|
|
||||||
|
|
||||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig # noqa: E402
|
|
||||||
from lerobot.utils.import_utils import _transformers_available # noqa: E402
|
|
||||||
|
|
||||||
# Conditional import for type checking and lazy loading
|
# Conditional import for type checking and lazy loading
|
||||||
if TYPE_CHECKING or _transformers_available:
|
if TYPE_CHECKING or _transformers_available:
|
||||||
@@ -52,9 +45,9 @@ if TYPE_CHECKING or _transformers_available:
|
|||||||
else:
|
else:
|
||||||
CLIPTextModel = None
|
CLIPTextModel = None
|
||||||
CLIPVisionModel = None
|
CLIPVisionModel = None
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy # noqa: E402
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import populate_queues # noqa: E402
|
from lerobot.policies.utils import populate_queues
|
||||||
from lerobot.utils.constants import ( # noqa: E402
|
from lerobot.utils.constants import (
|
||||||
ACTION,
|
ACTION,
|
||||||
OBS_IMAGES,
|
OBS_IMAGES,
|
||||||
OBS_LANGUAGE_ATTENTION_MASK,
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
@@ -648,6 +641,12 @@ class DiffusionObjective(nn.Module):
|
|||||||
"prediction_type": config.prediction_type,
|
"prediction_type": config.prediction_type,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
from lerobot.utils.import_utils import require_package
|
||||||
|
|
||||||
|
require_package("diffusers", extra="training")
|
||||||
|
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
|
|
||||||
if config.noise_scheduler_type == "DDPM":
|
if config.noise_scheduler_type == "DDPM":
|
||||||
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
|
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
|
||||||
elif config.noise_scheduler_type == "DDIM":
|
elif config.noise_scheduler_type == "DDIM":
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ def build_inference_frame(
|
|||||||
Returns:
|
Returns:
|
||||||
A dictionary of preprocessed tensors ready for model inference.
|
A dictionary of preprocessed tensors ready for model inference.
|
||||||
"""
|
"""
|
||||||
from lerobot.datasets.feature_utils import build_dataset_frame
|
from lerobot.utils.feature_utils import build_dataset_frame
|
||||||
|
|
||||||
# Extracts the correct keys from the incoming raw observation
|
# Extracts the correct keys from the incoming raw observation
|
||||||
observation = build_dataset_frame(ds_features, observation, prefix=OBS_STR)
|
observation = build_dataset_frame(ds_features, observation, prefix=OBS_STR)
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||||
from lerobot.datasets.factory import IMAGENET_STATS
|
|
||||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
|
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
@@ -40,6 +39,7 @@ from lerobot.processor import (
|
|||||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||||
from lerobot.types import EnvTransition, TransitionKey
|
from lerobot.types import EnvTransition, TransitionKey
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
|
IMAGENET_STATS,
|
||||||
OBS_IMAGES,
|
OBS_IMAGES,
|
||||||
OBS_PREFIX,
|
OBS_PREFIX,
|
||||||
OBS_STATE,
|
OBS_STATE,
|
||||||
|
|||||||
@@ -62,8 +62,7 @@ from torch.optim.optimizer import Optimizer
|
|||||||
from lerobot.cameras import opencv # noqa: F401
|
from lerobot.cameras import opencv # noqa: F401
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||||
from lerobot.datasets.factory import make_dataset
|
from lerobot.datasets import LeRobotDataset, make_dataset
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
||||||
from lerobot.policies.factory import make_policy
|
from lerobot.policies.factory import make_policy
|
||||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||||
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
|
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
|
||||||
|
|||||||
@@ -44,10 +44,9 @@ from huggingface_hub import HfApi
|
|||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.datasets import LeRobotDataset
|
from lerobot.datasets import LeRobotDataset, write_stats
|
||||||
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
|
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
|
||||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION
|
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION
|
||||||
from lerobot.datasets.io_utils import write_stats
|
|
||||||
from lerobot.utils.utils import init_logging
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -85,11 +85,13 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
|
|||||||
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
|
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.datasets import LeRobotDataset
|
from lerobot.datasets import (
|
||||||
from lerobot.datasets.feature_utils import build_dataset_frame, combine_feature_dicts
|
LeRobotDataset,
|
||||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
VideoEncodingManager,
|
||||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
aggregate_pipeline_dataset_features,
|
||||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
create_initial_features,
|
||||||
|
safe_stop_image_writer,
|
||||||
|
)
|
||||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.rtc import ActionInterpolator
|
from lerobot.policies.rtc import ActionInterpolator
|
||||||
@@ -143,6 +145,7 @@ from lerobot.utils.control_utils import (
|
|||||||
sanity_check_dataset_robot_compatibility,
|
sanity_check_dataset_robot_compatibility,
|
||||||
)
|
)
|
||||||
from lerobot.utils.device_utils import get_safe_torch_device
|
from lerobot.utils.device_utils import get_safe_torch_device
|
||||||
|
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||||
from lerobot.utils.import_utils import register_third_party_plugins
|
from lerobot.utils.import_utils import register_third_party_plugins
|
||||||
from lerobot.utils.robot_utils import precise_sleep
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
from lerobot.utils.utils import (
|
from lerobot.utils.utils import (
|
||||||
|
|||||||
@@ -13,45 +13,44 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from lerobot.utils.import_utils import require_package
|
if TYPE_CHECKING:
|
||||||
|
from accelerate import Accelerator
|
||||||
|
|
||||||
require_package("accelerate", extra="training")
|
import torch
|
||||||
|
from termcolor import colored
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch # noqa: E402
|
from lerobot.configs import parser
|
||||||
from accelerate import Accelerator # noqa: E402
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from termcolor import colored # noqa: E402
|
from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
||||||
from torch.optim import Optimizer # noqa: E402
|
from lerobot.envs.factory import make_env, make_env_pre_post_processors
|
||||||
from tqdm import tqdm # noqa: E402
|
from lerobot.envs.utils import close_envs
|
||||||
|
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||||
from lerobot.configs import parser # noqa: E402
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||||
from lerobot.configs.train import TrainPipelineConfig # noqa: E402
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.datasets import EpisodeAwareSampler # noqa: E402
|
from lerobot.rl.wandb_utils import WandBLogger
|
||||||
from lerobot.datasets.factory import make_dataset # noqa: E402
|
from lerobot.scripts.lerobot_eval import eval_policy_all
|
||||||
from lerobot.envs.factory import make_env, make_env_pre_post_processors # noqa: E402
|
from lerobot.utils.import_utils import register_third_party_plugins
|
||||||
from lerobot.envs.utils import close_envs # noqa: E402
|
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||||
from lerobot.optim.factory import make_optimizer_and_scheduler # noqa: E402
|
from lerobot.utils.random_utils import set_seed
|
||||||
from lerobot.policies.factory import make_policy, make_pre_post_processors # noqa: E402
|
from lerobot.utils.train_utils import (
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy # noqa: E402
|
|
||||||
from lerobot.rl.wandb_utils import WandBLogger # noqa: E402
|
|
||||||
from lerobot.scripts.lerobot_eval import eval_policy_all # noqa: E402
|
|
||||||
from lerobot.utils.import_utils import register_third_party_plugins # noqa: E402
|
|
||||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker # noqa: E402
|
|
||||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
|
||||||
from lerobot.utils.train_utils import ( # noqa: E402
|
|
||||||
get_step_checkpoint_dir,
|
get_step_checkpoint_dir,
|
||||||
get_step_identifier,
|
get_step_identifier,
|
||||||
load_training_state,
|
load_training_state,
|
||||||
save_checkpoint,
|
save_checkpoint,
|
||||||
update_last_checkpoint,
|
update_last_checkpoint,
|
||||||
)
|
)
|
||||||
from lerobot.utils.utils import ( # noqa: E402
|
from lerobot.utils.utils import (
|
||||||
cycle,
|
cycle,
|
||||||
format_big_number,
|
format_big_number,
|
||||||
has_method,
|
has_method,
|
||||||
@@ -171,6 +170,11 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
||||||
accelerator: Optional Accelerator instance. If None, one will be created automatically.
|
accelerator: Optional Accelerator instance. If None, one will be created automatically.
|
||||||
"""
|
"""
|
||||||
|
from lerobot.utils.import_utils import require_package
|
||||||
|
|
||||||
|
require_package("accelerate", extra="training")
|
||||||
|
from accelerate import Accelerator
|
||||||
|
|
||||||
cfg.validate()
|
cfg.validate()
|
||||||
|
|
||||||
# Create Accelerator if not provided
|
# Create Accelerator if not provided
|
||||||
|
|||||||
@@ -75,6 +75,21 @@ default_calibration_path = HF_LEROBOT_HOME / "calibration"
|
|||||||
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()
|
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()
|
||||||
|
|
||||||
|
|
||||||
|
# Dataset meta-features (auto-populated by the recording pipeline)
|
||||||
|
DEFAULT_FEATURES = {
|
||||||
|
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
|
||||||
|
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||||
|
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||||
|
"index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||||
|
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||||
|
}
|
||||||
|
|
||||||
|
# ImageNet normalization constants
|
||||||
|
IMAGENET_STATS = {
|
||||||
|
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
||||||
|
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
|
||||||
|
}
|
||||||
|
|
||||||
# streaming datasets
|
# streaming datasets
|
||||||
LOOKBACK_BACKTRACKTABLE = 100
|
LOOKBACK_BACKTRACKTABLE = 100
|
||||||
LOOKAHEAD_BACKTRACKTABLE = 100
|
LOOKAHEAD_BACKTRACKTABLE = 100
|
||||||
|
|||||||
@@ -223,7 +223,7 @@ def sanity_check_dataset_robot_compatibility(
|
|||||||
require_package("deepdiff", extra="hardware")
|
require_package("deepdiff", extra="hardware")
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
|
|
||||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||||
|
|
||||||
fields = [
|
fields = [
|
||||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||||
|
|||||||
@@ -0,0 +1,222 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Lightweight feature-manipulation utilities.
|
||||||
|
|
||||||
|
These functions are intentionally kept free of heavy dependencies (e.g. the
|
||||||
|
HuggingFace ``datasets`` library) so that they can be imported from anywhere
|
||||||
|
in the codebase – including modules that are part of the *minimal* install –
|
||||||
|
without triggering the ``lerobot.datasets`` package guard.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
|
from lerobot.utils.constants import ACTION, DEFAULT_FEATURES, OBS_ENV_STATE, OBS_STR
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_feature_names(features: dict[str, dict]) -> None:
|
||||||
|
"""Validate that feature names do not contain invalid characters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features (dict): The LeRobot features dictionary.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any feature name contains '/'.
|
||||||
|
"""
|
||||||
|
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
|
||||||
|
if invalid_features:
|
||||||
|
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
|
||||||
|
|
||||||
|
|
||||||
|
def hw_to_dataset_features(
|
||||||
|
hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
|
||||||
|
) -> dict[str, dict]:
|
||||||
|
"""Convert hardware-specific features to a LeRobot dataset feature dictionary.
|
||||||
|
|
||||||
|
This function takes a dictionary describing hardware outputs (like joint states
|
||||||
|
or camera image shapes) and formats it into the standard LeRobot feature
|
||||||
|
specification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hw_features (dict): Dictionary mapping feature names to their type (float for
|
||||||
|
joints) or shape (tuple for images).
|
||||||
|
prefix (str): The prefix to add to the feature keys (e.g., "observation"
|
||||||
|
or "action").
|
||||||
|
use_video (bool): If True, image features are marked as "video", otherwise "image".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A LeRobot features dictionary.
|
||||||
|
"""
|
||||||
|
features = {}
|
||||||
|
joint_fts = {
|
||||||
|
key: ftype
|
||||||
|
for key, ftype in hw_features.items()
|
||||||
|
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
|
||||||
|
}
|
||||||
|
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||||
|
|
||||||
|
if joint_fts and prefix == ACTION:
|
||||||
|
features[prefix] = {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (len(joint_fts),),
|
||||||
|
"names": list(joint_fts),
|
||||||
|
}
|
||||||
|
|
||||||
|
if joint_fts and prefix == OBS_STR:
|
||||||
|
features[f"{prefix}.state"] = {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (len(joint_fts),),
|
||||||
|
"names": list(joint_fts),
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, shape in cam_fts.items():
|
||||||
|
features[f"{prefix}.images.{key}"] = {
|
||||||
|
"dtype": "video" if use_video else "image",
|
||||||
|
"shape": shape,
|
||||||
|
"names": ["height", "width", "channels"],
|
||||||
|
}
|
||||||
|
|
||||||
|
_validate_feature_names(features)
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def build_dataset_frame(
|
||||||
|
ds_features: dict[str, dict], values: dict[str, Any], prefix: str
|
||||||
|
) -> dict[str, np.ndarray]:
|
||||||
|
"""Construct a single data frame from raw values based on dataset features.
|
||||||
|
|
||||||
|
A "frame" is a dictionary containing all the data for a single timestep,
|
||||||
|
formatted as numpy arrays according to the feature specification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ds_features (dict): The LeRobot dataset features dictionary.
|
||||||
|
values (dict): A dictionary of raw values from the hardware/environment.
|
||||||
|
prefix (str): The prefix to filter features by (e.g., "observation"
|
||||||
|
or "action").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary representing a single frame of data.
|
||||||
|
"""
|
||||||
|
frame = {}
|
||||||
|
for key, ft in ds_features.items():
|
||||||
|
if key in DEFAULT_FEATURES or not key.startswith(prefix):
|
||||||
|
continue
|
||||||
|
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
||||||
|
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||||
|
elif ft["dtype"] in ["image", "video"]:
|
||||||
|
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||||
|
|
||||||
|
return frame
|
||||||
|
|
||||||
|
|
||||||
|
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||||
|
"""Convert dataset features to policy features.
|
||||||
|
|
||||||
|
This function transforms the dataset's feature specification into a format
|
||||||
|
that a policy can use, classifying features by type (e.g., visual, state,
|
||||||
|
action) and ensuring correct shapes (e.g., channel-first for images).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features (dict): The LeRobot dataset features dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary mapping feature keys to `PolicyFeature` objects.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If an image feature does not have a 3D shape.
|
||||||
|
"""
|
||||||
|
# TODO(aliberts): Implement "type" in dataset features and simplify this
|
||||||
|
policy_features = {}
|
||||||
|
for key, ft in features.items():
|
||||||
|
shape = ft["shape"]
|
||||||
|
if ft["dtype"] in ["image", "video"]:
|
||||||
|
type = FeatureType.VISUAL
|
||||||
|
if len(shape) != 3:
|
||||||
|
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||||
|
|
||||||
|
names = ft["names"]
|
||||||
|
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||||
|
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||||
|
shape = (shape[2], shape[0], shape[1])
|
||||||
|
elif key == OBS_ENV_STATE:
|
||||||
|
type = FeatureType.ENV
|
||||||
|
elif key.startswith(OBS_STR):
|
||||||
|
type = FeatureType.STATE
|
||||||
|
elif key.startswith(ACTION):
|
||||||
|
type = FeatureType.ACTION
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
policy_features[key] = PolicyFeature(
|
||||||
|
type=type,
|
||||||
|
shape=shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
return policy_features
|
||||||
|
|
||||||
|
|
||||||
|
def combine_feature_dicts(*dicts: dict) -> dict:
|
||||||
|
"""Merge LeRobot grouped feature dicts.
|
||||||
|
|
||||||
|
- For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
|
||||||
|
- For others (e.g. `observation.images.*`), the last one wins (if they are identical).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*dicts: A variable number of LeRobot feature dictionaries to merge.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A single merged feature dictionary.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If there's a dtype mismatch for a feature being merged.
|
||||||
|
"""
|
||||||
|
out: dict = {}
|
||||||
|
for d in dicts:
|
||||||
|
for key, value in d.items():
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
out[key] = value
|
||||||
|
continue
|
||||||
|
|
||||||
|
dtype = value.get("dtype")
|
||||||
|
shape = value.get("shape")
|
||||||
|
is_vector = (
|
||||||
|
dtype not in ("image", "video", "string")
|
||||||
|
and isinstance(shape, tuple)
|
||||||
|
and len(shape) == 1
|
||||||
|
and "names" in value
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_vector:
|
||||||
|
# Initialize or retrieve the accumulating dict for this feature key
|
||||||
|
target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
|
||||||
|
# Ensure consistent data types across merged entries
|
||||||
|
if "dtype" in target and dtype != target["dtype"]:
|
||||||
|
raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
|
||||||
|
|
||||||
|
# Merge feature names: append only new ones to preserve order without duplicates
|
||||||
|
seen = set(target["names"])
|
||||||
|
for n in value["names"]:
|
||||||
|
if n not in seen:
|
||||||
|
target["names"].append(n)
|
||||||
|
seen.add(n)
|
||||||
|
# Recompute the shape to reflect the updated number of features
|
||||||
|
target["shape"] = (len(target["names"]),)
|
||||||
|
else:
|
||||||
|
# For images/videos and non-1D entries: override with the latest definition
|
||||||
|
out[key] = value
|
||||||
|
return out
|
||||||
@@ -90,7 +90,8 @@ def require_package(pkg_name: str, extra: str, import_name: str | None = None) -
|
|||||||
_require_package_cache[cache_key] = is_package_available(pkg_name, import_name)
|
_require_package_cache[cache_key] = is_package_available(pkg_name, import_name)
|
||||||
if not _require_package_cache[cache_key]:
|
if not _require_package_cache[cache_key]:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"'{pkg_name}' is required but not installed. Install it with: pip install 'lerobot[{extra}]'"
|
f"'{pkg_name}' is required but not installed. Install it with: "
|
||||||
|
f"pip install 'lerobot[{extra}]' (or uv pip install 'lerobot[{extra}]')"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -81,6 +81,8 @@ def write_video(video_path: str | Path, stacked_frames: list, fps: int) -> None:
|
|||||||
stream.height = height
|
stream.height = height
|
||||||
stream.pix_fmt = "yuv420p"
|
stream.pix_fmt = "yuv420p"
|
||||||
for frame_array in stacked_frames:
|
for frame_array in stacked_frames:
|
||||||
|
if height != orig_height or width != orig_width:
|
||||||
|
frame_array = frame_array[:height, :width]
|
||||||
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
||||||
for packet in stream.encode(frame):
|
for packet in stream.encode(frame):
|
||||||
container.mux(packet)
|
container.mux(packet)
|
||||||
|
|||||||
@@ -292,9 +292,8 @@ class SuppressProgressBars:
|
|||||||
|
|
||||||
disable_progress_bar()
|
disable_progress_bar()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logging.getLogger(__name__).info(
|
logging.getLogger(__name__).debug(
|
||||||
"SuppressProgressBars is a no-op because 'datasets' is not installed. "
|
"SuppressProgressBars is a no-op because 'datasets' is not installed."
|
||||||
"Install it with: pip install 'lerobot[dataset]'"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from safetensors.torch import save_file
|
|||||||
|
|
||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.datasets.factory import make_dataset
|
from lerobot.datasets import make_dataset
|
||||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||||
from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors
|
from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors
|
||||||
from lerobot.utils.constants import OBS_STR
|
from lerobot.utils.constants import OBS_STR
|
||||||
|
|||||||
@@ -19,10 +19,10 @@ import torch
|
|||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from huggingface_hub import DatasetCard
|
from huggingface_hub import DatasetCard
|
||||||
|
|
||||||
from lerobot.datasets.feature_utils import combine_feature_dicts
|
|
||||||
from lerobot.datasets.io_utils import hf_transform_to_torch
|
from lerobot.datasets.io_utils import hf_transform_to_torch
|
||||||
from lerobot.datasets.utils import create_lerobot_dataset_card
|
from lerobot.datasets.utils import create_lerobot_dataset_card
|
||||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||||
|
from lerobot.utils.feature_utils import combine_feature_dicts
|
||||||
|
|
||||||
|
|
||||||
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
|
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
|
||||||
|
|||||||
@@ -29,8 +29,8 @@ from torchvision.transforms import v2
|
|||||||
import lerobot
|
import lerobot
|
||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.datasets.factory import make_dataset
|
from lerobot.datasets import make_dataset
|
||||||
from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features
|
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||||
from lerobot.datasets.image_writer import image_array_to_pil_image
|
from lerobot.datasets.image_writer import image_array_to_pil_image
|
||||||
from lerobot.datasets.io_utils import hf_transform_to_torch
|
from lerobot.datasets.io_utils import hf_transform_to_torch
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
@@ -47,6 +47,7 @@ from lerobot.policies.factory import make_policy_config
|
|||||||
from lerobot.robots import make_robot_from_config
|
from lerobot.robots import make_robot_from_config
|
||||||
from lerobot.transforms import ImageTransforms, ImageTransformsConfig
|
from lerobot.transforms import ImageTransforms, ImageTransformsConfig
|
||||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||||
|
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||||
from tests.mocks.mock_robot import MockRobotConfig
|
from tests.mocks.mock_robot import MockRobotConfig
|
||||||
from tests.utils import require_x86_64_kernel
|
from tests.utils import require_x86_64_kernel
|
||||||
|
|||||||
@@ -27,8 +27,7 @@ from lerobot import available_policies
|
|||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
from lerobot.datasets.factory import make_dataset
|
from lerobot.datasets import make_dataset
|
||||||
from lerobot.datasets.feature_utils import dataset_to_policy_features
|
|
||||||
from lerobot.envs.factory import make_env, make_env_config
|
from lerobot.envs.factory import make_env, make_env_config
|
||||||
from lerobot.envs.utils import close_envs, preprocess_observation
|
from lerobot.envs.utils import close_envs, preprocess_observation
|
||||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||||
@@ -44,6 +43,7 @@ from lerobot.policies.pretrained import PreTrainedPolicy
|
|||||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead
|
from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead
|
||||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||||
|
from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||||
from lerobot.utils.random_utils import seeded_context
|
from lerobot.utils.random_utils import seeded_context
|
||||||
from lerobot.utils.utils import cycle
|
from lerobot.utils.utils import cycle
|
||||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||||
|
|||||||
Reference in New Issue
Block a user