mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 00:07:03 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 287c823f13 |
@@ -79,8 +79,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
||||
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
||||
pretrained_path: Path | None = None
|
||||
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained model version.
|
||||
pretrained_revision: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
|
||||
@@ -56,8 +56,6 @@ class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
device: str | None = None
|
||||
|
||||
pretrained_path: str | None = None
|
||||
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained reward model version.
|
||||
pretrained_revision: str | None = None
|
||||
|
||||
push_to_hub: bool = False
|
||||
repo_id: str | None = None
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -709,7 +710,7 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
features = {**deepcopy(features), **DEFAULT_FEATURES}
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
|
||||
@@ -27,6 +27,7 @@ import logging
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -1101,7 +1102,9 @@ def _copy_episodes_metadata_and_stats(
|
||||
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
||||
for key in dst_meta.video_keys:
|
||||
if key in src_dataset.meta.features:
|
||||
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
|
||||
dst_meta.info.features[key]["info"] = deepcopy(
|
||||
src_dataset.meta.info.features[key].get("info", {})
|
||||
)
|
||||
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
|
||||
@@ -252,7 +252,6 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
def make_pre_post_processors(
|
||||
policy_cfg: PreTrainedConfig,
|
||||
pretrained_path: str | None = None,
|
||||
pretrained_revision: str | None = None,
|
||||
**kwargs: Unpack[ProcessorConfigKwargs],
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
@@ -310,7 +309,6 @@ def make_pre_post_processors(
|
||||
overrides=kwargs.get("preprocessor_overrides", {}),
|
||||
to_transition=batch_to_transition,
|
||||
to_output=transition_to_batch,
|
||||
revision=pretrained_revision,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
@@ -320,7 +318,6 @@ def make_pre_post_processors(
|
||||
overrides=kwargs.get("postprocessor_overrides", {}),
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
revision=pretrained_revision,
|
||||
)
|
||||
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
@@ -560,7 +557,6 @@ def make_policy(
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
kwargs["revision"] = cfg.pretrained_revision
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
elif cfg.pretrained_path and cfg.use_peft:
|
||||
# Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
|
||||
|
||||
@@ -124,7 +124,6 @@ def make_reward_model(cfg: RewardModelConfig, **kwargs) -> PreTrainedRewardModel
|
||||
|
||||
if cfg.pretrained_path:
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
kwargs["revision"] = cfg.pretrained_revision
|
||||
reward_model = reward_cls.from_pretrained(**kwargs)
|
||||
else:
|
||||
reward_model = reward_cls(**kwargs)
|
||||
|
||||
@@ -345,7 +345,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
pretrained_revision=getattr(cfg.policy, "pretrained_revision", None),
|
||||
**processor_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ from lerobot.robots import make_robot_from_config
|
||||
from lerobot.transforms import ImageTransforms, ImageTransformsConfig
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_MOTOR_FEATURES, DUMMY_REPO_ID
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
@@ -133,6 +133,21 @@ def test_dataset_feature_with_forward_slash_raises_error():
|
||||
)
|
||||
|
||||
|
||||
def test_create_does_not_mutate_input_features(tmp_path, empty_lerobot_dataset_factory):
|
||||
# ``create`` must deep-copy features so a dataset built from another's features stays independent.
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds1", features=DUMMY_MOTOR_FEATURES, use_videos=False
|
||||
)
|
||||
dataset_copy = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds2", features=dataset.meta.features, use_videos=False
|
||||
)
|
||||
|
||||
original_shape = dataset.meta.info.features["state"]["shape"]
|
||||
dataset_copy.meta.info.features["state"]["shape"] = (999,)
|
||||
|
||||
assert dataset.meta.info.features["state"]["shape"] == original_shape
|
||||
|
||||
|
||||
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
|
||||
Reference in New Issue
Block a user