mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-03 16:17:15 +00:00
temp_training
This commit is contained in:
@@ -72,10 +72,11 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
raise ValueError(
|
||||
f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
|
||||
)
|
||||
if features != meta.features:
|
||||
raise ValueError(
|
||||
f"Same features is expected, but got features={meta.features} instead of {features}."
|
||||
)
|
||||
# TODO: Temporarily disabled for merging datasets with different features (e.g. shirt_id)
|
||||
# if features != meta.features:
|
||||
# raise ValueError(
|
||||
# f"Same features is expected, but got features={meta.features} instead of {features}."
|
||||
# )
|
||||
|
||||
return fps, robot_type, features
|
||||
|
||||
|
||||
@@ -560,7 +560,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episodes: list[int] | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
tolerance_s: float = 1e-2,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
@@ -1563,7 +1563,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
root: str | Path | None = None,
|
||||
robot_type: str | None = None,
|
||||
use_videos: bool = True,
|
||||
tolerance_s: float = 1e-4,
|
||||
tolerance_s: float = 1e-2,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
|
||||
@@ -61,8 +61,6 @@ class PI05Config(PreTrainedConfig):
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
empty_cameras: int = 0
|
||||
|
||||
tokenizer_max_length: int = 200 # see openpi `__post_init__`
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
|
||||
@@ -337,13 +337,28 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
# Filter out episodes - hardcoded list of bad episodes to discard
|
||||
episodes_to_discard = {
|
||||
133, 134, 502, 565, 568, 657, 910, 944, 1039, 1209, 1346, 1360, 1379,
|
||||
1605, 1690, 1790, 2105, 2106, 2122, 2118, 2156, 2575, 2764, 2876, 2925,
|
||||
3100, 3381, 3405, 3406, 68, 1214, 1456,
|
||||
}
|
||||
all_episodes = set(range(dataset.meta.total_episodes))
|
||||
episodes_to_use = dataset.episodes # May be None (all episodes) or a subset
|
||||
# If dataset.episodes is already filtered, start from that subset
|
||||
if episodes_to_use is not None:
|
||||
episodes_to_use = [ep for ep in episodes_to_use if ep not in episodes_to_discard]
|
||||
else:
|
||||
episodes_to_use = sorted(all_episodes - episodes_to_discard)
|
||||
|
||||
if hasattr(cfg.policy, "drop_n_last_frames") or episodes_to_use is not None:
|
||||
shuffle = False
|
||||
drop_n_last = getattr(cfg.policy, "drop_n_last_frames", 0)
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
episode_indices_to_use=episodes_to_use,
|
||||
drop_n_last_frames=drop_n_last,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user