temp_training

This commit is contained in:
Michel Aractingi
2026-01-26 15:19:52 +00:00
parent 9e10eb4a77
commit 9cc203034e
4 changed files with 25 additions and 11 deletions
+5 -4
View File
@@ -72,10 +72,11 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
raise ValueError(
f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
)
if features != meta.features:
raise ValueError(
f"Same features is expected, but got features={meta.features} instead of {features}."
)
# TODO: Temporarily disabled for merging datasets with different features (e.g. shirt_id)
# if features != meta.features:
# raise ValueError(
# f"Same features is expected, but got features={meta.features} instead of {features}."
# )
return fps, robot_type, features
+2 -2
View File
@@ -560,7 +560,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episodes: list[int] | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
tolerance_s: float = 1e-2,
revision: str | None = None,
force_cache_sync: bool = False,
download_videos: bool = True,
@@ -1563,7 +1563,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
root: str | Path | None = None,
robot_type: str | None = None,
use_videos: bool = True,
tolerance_s: float = 1e-4,
tolerance_s: float = 1e-2,
image_writer_processes: int = 0,
image_writer_threads: int = 0,
video_backend: str | None = None,
@@ -61,8 +61,6 @@ class PI05Config(PreTrainedConfig):
# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0
tokenizer_max_length: int = 200 # see openpi `__post_init__`
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
+18 -3
View File
@@ -337,13 +337,28 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"):
# Filter out episodes - hardcoded list of bad episodes to discard
episodes_to_discard = {
133, 134, 502, 565, 568, 657, 910, 944, 1039, 1209, 1346, 1360, 1379,
1605, 1690, 1790, 2105, 2106, 2122, 2118, 2156, 2575, 2764, 2876, 2925,
3100, 3381, 3405, 3406, 68, 1214, 1456,
}
all_episodes = set(range(dataset.meta.total_episodes))
episodes_to_use = dataset.episodes # May be None (all episodes) or a subset
# If dataset.episodes is already filtered, start from that subset
if episodes_to_use is not None:
episodes_to_use = [ep for ep in episodes_to_use if ep not in episodes_to_discard]
else:
episodes_to_use = sorted(all_episodes - episodes_to_discard)
if hasattr(cfg.policy, "drop_n_last_frames") or episodes_to_use is not None:
shuffle = False
drop_n_last = getattr(cfg.policy, "drop_n_last_frames", 0)
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
episode_indices_to_use=episodes_to_use,
drop_n_last_frames=drop_n_last,
shuffle=True,
)
else: