mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 08:17:02 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9449e68725 | |||
| 2b83956eb5 | |||
| 7309790d56 | |||
| 2201401c99 | |||
| 64773e7b22 |
@@ -167,9 +167,9 @@ jobs:
|
|||||||
|
|
||||||
# ── LIBERO TRAIN+EVAL SMOKE ──────────────────────────────────────────────
|
# ── LIBERO TRAIN+EVAL SMOKE ──────────────────────────────────────────────
|
||||||
# Train SmolVLA for 1 step (batch_size=1, dataset episode 0 only) then
|
# Train SmolVLA for 1 step (batch_size=1, dataset episode 0 only) then
|
||||||
# immediately runs eval inside the training loop (eval_freq=1, 1 episode).
|
# immediately runs eval inside the training loop (env_eval_freq=1, 1 episode).
|
||||||
# Tests the full train→eval-within-training pipeline end-to-end.
|
# Tests the full train→eval-within-training pipeline end-to-end.
|
||||||
- name: Run Libero train+eval smoke (1 step, eval_freq=1)
|
- name: Run Libero train+eval smoke (1 step, env_eval_freq=1)
|
||||||
if: env.HF_USER_TOKEN != ''
|
if: env.HF_USER_TOKEN != ''
|
||||||
run: |
|
run: |
|
||||||
docker run --name libero-train-smoke --gpus all \
|
docker run --name libero-train-smoke --gpus all \
|
||||||
@@ -196,7 +196,7 @@ jobs:
|
|||||||
--output_dir=/tmp/train-smoke \
|
--output_dir=/tmp/train-smoke \
|
||||||
--steps=1 \
|
--steps=1 \
|
||||||
--batch_size=1 \
|
--batch_size=1 \
|
||||||
--eval_freq=1 \
|
--env_eval_freq=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.use_async_envs=false \
|
--eval.use_async_envs=false \
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ test-act-ete-train:
|
|||||||
--dataset.episodes="[0]" \
|
--dataset.episodes="[0]" \
|
||||||
--batch_size=2 \
|
--batch_size=2 \
|
||||||
--steps=4 \
|
--steps=4 \
|
||||||
--eval_freq=2 \
|
--env_eval_freq=2 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--save_freq=2 \
|
--save_freq=2 \
|
||||||
@@ -96,7 +96,7 @@ test-diffusion-ete-train:
|
|||||||
--dataset.episodes="[0]" \
|
--dataset.episodes="[0]" \
|
||||||
--batch_size=2 \
|
--batch_size=2 \
|
||||||
--steps=2 \
|
--steps=2 \
|
||||||
--eval_freq=2 \
|
--env_eval_freq=2 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--save_checkpoint=true \
|
--save_checkpoint=true \
|
||||||
@@ -126,7 +126,7 @@ test-tdmpc-ete-train:
|
|||||||
--dataset.episodes="[0]" \
|
--dataset.episodes="[0]" \
|
||||||
--batch_size=2 \
|
--batch_size=2 \
|
||||||
--steps=2 \
|
--steps=2 \
|
||||||
--eval_freq=2 \
|
--env_eval_freq=2 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--save_checkpoint=true \
|
--save_checkpoint=true \
|
||||||
@@ -161,7 +161,7 @@ test-smolvla-ete-train:
|
|||||||
--dataset.episodes="[0]" \
|
--dataset.episodes="[0]" \
|
||||||
--batch_size=2 \
|
--batch_size=2 \
|
||||||
--steps=4 \
|
--steps=4 \
|
||||||
--eval_freq=2 \
|
--env_eval_freq=2 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--save_freq=2 \
|
--save_freq=2 \
|
||||||
|
|||||||
@@ -719,7 +719,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
|||||||
"num_workers": 4,
|
"num_workers": 4,
|
||||||
"steps": 5000,
|
"steps": 5000,
|
||||||
"log_freq": 10,
|
"log_freq": 10,
|
||||||
"eval_freq": 1000,
|
"env_eval_freq": 1000,
|
||||||
"save_freq": 1000,
|
"save_freq": 1000,
|
||||||
"save_checkpoint": true,
|
"save_checkpoint": true,
|
||||||
"seed": 2,
|
"seed": 2,
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ lerobot-train \
|
|||||||
--batch_size=4 \
|
--batch_size=4 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval_freq=1000
|
--env_eval_freq=1000
|
||||||
```
|
```
|
||||||
|
|
||||||
## Reproducing published results
|
## Reproducing published results
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ lerobot-train \
|
|||||||
--batch_size=4 \
|
--batch_size=4 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval_freq=1000
|
--env_eval_freq=1000
|
||||||
```
|
```
|
||||||
|
|
||||||
## Relationship to LIBERO
|
## Relationship to LIBERO
|
||||||
|
|||||||
@@ -120,11 +120,11 @@ lerobot-train \
|
|||||||
--batch_size=4 \
|
--batch_size=4 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval_freq=1000
|
--env_eval_freq=1000
|
||||||
```
|
```
|
||||||
|
|
||||||
## Practical tips
|
## Practical tips
|
||||||
|
|
||||||
- Use the one-hot task conditioning for multi-task training (MT10/MT50 conventions) so policies have explicit task context.
|
- Use the one-hot task conditioning for multi-task training (MT10/MT50 conventions) so policies have explicit task context.
|
||||||
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
|
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
|
||||||
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.
|
- Adjust `batch_size`, `steps`, and `env_eval_freq` to match your compute budget.
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ accelerate launch \
|
|||||||
--batch_size=32 \
|
--batch_size=32 \
|
||||||
--num_workers=4 \
|
--num_workers=4 \
|
||||||
--log_freq=20 \
|
--log_freq=20 \
|
||||||
--eval_freq=-1 \
|
--env_eval_freq=-1 \
|
||||||
--save_checkpoint=true \
|
--save_checkpoint=true \
|
||||||
--save_freq=2000
|
--save_freq=2000
|
||||||
```
|
```
|
||||||
@@ -142,7 +142,7 @@ accelerate launch \
|
|||||||
--batch_size=32 \
|
--batch_size=32 \
|
||||||
--num_workers=4 \
|
--num_workers=4 \
|
||||||
--log_freq=20 \
|
--log_freq=20 \
|
||||||
--eval_freq=-1 \
|
--env_eval_freq=-1 \
|
||||||
--save_checkpoint=true \
|
--save_checkpoint=true \
|
||||||
--save_freq=2000
|
--save_freq=2000
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -314,7 +314,7 @@ lerobot-train \
|
|||||||
--steps=30000 \
|
--steps=30000 \
|
||||||
--save_freq=1000 \
|
--save_freq=1000 \
|
||||||
--log_freq=100 \
|
--log_freq=100 \
|
||||||
--eval_freq=1000 \
|
--env_eval_freq=1000 \
|
||||||
--policy.type=multi_task_dit \
|
--policy.type=multi_task_dit \
|
||||||
--policy.device=cuda \
|
--policy.device=cuda \
|
||||||
--policy.horizon=32 \
|
--policy.horizon=32 \
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ lerobot-train \
|
|||||||
--output_dir=./outputs/smolvla_robocasa_CloseFridge \
|
--output_dir=./outputs/smolvla_robocasa_CloseFridge \
|
||||||
--steps=100000 \
|
--steps=100000 \
|
||||||
--batch_size=4 \
|
--batch_size=4 \
|
||||||
--eval_freq=5000 \
|
--env_eval_freq=5000 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=5 \
|
--eval.n_episodes=5 \
|
||||||
--save_freq=10000
|
--save_freq=10000
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ lerobot-train \
|
|||||||
--output_dir=./outputs/smolvla_vlabench_primitive \
|
--output_dir=./outputs/smolvla_vlabench_primitive \
|
||||||
--steps=100000 \
|
--steps=100000 \
|
||||||
--batch_size=4 \
|
--batch_size=4 \
|
||||||
--eval_freq=5000 \
|
--env_eval_freq=5000 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--save_freq=10000
|
--save_freq=10000
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ class DatasetConfig:
|
|||||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||||
return_uint8: bool = False
|
return_uint8: bool = False
|
||||||
streaming: bool = False
|
streaming: bool = False
|
||||||
|
# Fraction of episodes held out per task for offline evaluation (0.0 = disabled).
|
||||||
|
eval_split: float = 0.0
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.episodes is not None:
|
if self.episodes is not None:
|
||||||
|
|||||||
@@ -100,8 +100,13 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
prefetch_factor: int = 4
|
prefetch_factor: int = 4
|
||||||
persistent_workers: bool = True
|
persistent_workers: bool = True
|
||||||
steps: int = 100_000
|
steps: int = 100_000
|
||||||
eval_freq: int = 20_000
|
# Run policy in the simulation environment every N steps to measure reward/success (0 = disabled).
|
||||||
|
env_eval_freq: int = 20_000
|
||||||
log_freq: int = 200
|
log_freq: int = 200
|
||||||
|
# Compute eval loss on held-out episodes every N steps (0 = disabled). Requires eval_split > 0.
|
||||||
|
eval_steps: int = 0
|
||||||
|
# Cap on total eval samples, split uniformly across tasks (0 = use all held-out data).
|
||||||
|
max_eval_samples: int = 0
|
||||||
tolerance_s: float = 1e-4
|
tolerance_s: float = 1e-4
|
||||||
save_checkpoint: bool = True
|
save_checkpoint: bool = True
|
||||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from .dataset_tools import (
|
|||||||
remove_feature,
|
remove_feature,
|
||||||
split_dataset,
|
split_dataset,
|
||||||
)
|
)
|
||||||
from .factory import make_dataset, resolve_delta_timestamps
|
from .factory import make_dataset, make_train_eval_datasets, resolve_delta_timestamps
|
||||||
from .image_writer import safe_stop_image_writer
|
from .image_writer import safe_stop_image_writer
|
||||||
from .io_utils import load_episodes, write_stats
|
from .io_utils import load_episodes, write_stats
|
||||||
from .language import (
|
from .language import (
|
||||||
@@ -89,6 +89,7 @@ __all__ = [
|
|||||||
"get_feature_stats",
|
"get_feature_stats",
|
||||||
"load_episodes",
|
"load_episodes",
|
||||||
"make_dataset",
|
"make_dataset",
|
||||||
|
"make_train_eval_datasets",
|
||||||
"merge_datasets",
|
"merge_datasets",
|
||||||
"modify_features",
|
"modify_features",
|
||||||
"modify_tasks",
|
"modify_tasks",
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -130,3 +131,81 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
|||||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def make_train_eval_datasets(
|
||||||
|
cfg: TrainPipelineConfig,
|
||||||
|
) -> tuple[LeRobotDataset | MultiLeRobotDataset, LeRobotDataset | None]:
|
||||||
|
"""Create train and optional eval datasets by splitting episodes based on eval_split.
|
||||||
|
|
||||||
|
The last ceil(n_episodes * eval_split) episodes per task are held out for evaluation.
|
||||||
|
If eval_split == 0.0, returns (full_dataset, None).
|
||||||
|
"""
|
||||||
|
full_dataset = make_dataset(cfg)
|
||||||
|
|
||||||
|
if cfg.dataset.eval_split == 0.0:
|
||||||
|
return full_dataset, None
|
||||||
|
|
||||||
|
base_episodes = (
|
||||||
|
full_dataset.episodes if full_dataset.episodes is not None else list(range(full_dataset.num_episodes))
|
||||||
|
)
|
||||||
|
|
||||||
|
episode_tasks = full_dataset.meta.episodes["tasks"]
|
||||||
|
task_to_episodes: dict[str, list[int]] = {}
|
||||||
|
for ep_idx in base_episodes:
|
||||||
|
task_key = episode_tasks[ep_idx][0] if episode_tasks[ep_idx] else ""
|
||||||
|
task_to_episodes.setdefault(task_key, []).append(ep_idx)
|
||||||
|
|
||||||
|
train_episodes, eval_episodes = [], []
|
||||||
|
for eps in task_to_episodes.values():
|
||||||
|
n_eval = math.ceil(len(eps) * cfg.dataset.eval_split)
|
||||||
|
train_episodes.extend(eps[: len(eps) - n_eval])
|
||||||
|
eval_episodes.extend(eps[len(eps) - n_eval :])
|
||||||
|
|
||||||
|
if not train_episodes:
|
||||||
|
raise ValueError(
|
||||||
|
f"eval_split={cfg.dataset.eval_split} leaves 0 training episodes from {len(base_episodes)} total."
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"Train/eval split: {len(train_episodes)} train, {len(eval_episodes)} eval "
|
||||||
|
f"(eval_split={cfg.dataset.eval_split}, {len(task_to_episodes)} tasks)"
|
||||||
|
)
|
||||||
|
|
||||||
|
delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, full_dataset.meta)
|
||||||
|
|
||||||
|
train_image_transforms = (
|
||||||
|
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset = LeRobotDataset(
|
||||||
|
cfg.dataset.repo_id,
|
||||||
|
root=cfg.dataset.root,
|
||||||
|
episodes=train_episodes,
|
||||||
|
delta_timestamps=delta_timestamps,
|
||||||
|
image_transforms=train_image_transforms,
|
||||||
|
revision=cfg.dataset.revision,
|
||||||
|
video_backend=cfg.dataset.video_backend,
|
||||||
|
return_uint8=True,
|
||||||
|
tolerance_s=cfg.tolerance_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_dataset = LeRobotDataset(
|
||||||
|
cfg.dataset.repo_id,
|
||||||
|
root=cfg.dataset.root,
|
||||||
|
episodes=eval_episodes,
|
||||||
|
delta_timestamps=delta_timestamps,
|
||||||
|
image_transforms=None,
|
||||||
|
revision=cfg.dataset.revision,
|
||||||
|
video_backend=cfg.dataset.video_backend,
|
||||||
|
return_uint8=True,
|
||||||
|
tolerance_s=cfg.tolerance_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.dataset.use_imagenet_stats:
|
||||||
|
for ds in (train_dataset, eval_dataset):
|
||||||
|
for key in ds.meta.camera_keys:
|
||||||
|
for stats_type, stats in IMAGENET_STATS.items():
|
||||||
|
ds.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||||
|
|
||||||
|
return train_dataset, eval_dataset
|
||||||
|
|||||||
@@ -370,6 +370,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.reader.load_and_activate()
|
self.reader.load_and_activate()
|
||||||
return self.reader.hf_dataset
|
return self.reader.hf_dataset
|
||||||
|
|
||||||
|
@property
|
||||||
|
def absolute_to_relative_idx(self) -> dict[int, int] | None:
|
||||||
|
"""Mapping from absolute frame indices to HF dataset row positions.
|
||||||
|
|
||||||
|
Non-None only for episode-filtered datasets where absolute indices
|
||||||
|
(from metadata) differ from row positions in the loaded HF dataset.
|
||||||
|
"""
|
||||||
|
reader = self._ensure_reader()
|
||||||
|
if reader.hf_dataset is None:
|
||||||
|
reader.load_and_activate()
|
||||||
|
return reader._absolute_to_relative_idx
|
||||||
|
|
||||||
# ── Writer-delegated methods ──────────────────────────────────────
|
# ── Writer-delegated methods ──────────────────────────────────────
|
||||||
|
|
||||||
def add_frame(self, frame: dict) -> None:
|
def add_frame(self, frame: dict) -> None:
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ class EpisodeAwareSampler:
|
|||||||
drop_n_last_frames: int = 0,
|
drop_n_last_frames: int = 0,
|
||||||
shuffle: bool = False,
|
shuffle: bool = False,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
|
absolute_to_relative_idx: dict[int, int] | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -107,6 +108,7 @@ class EpisodeAwareSampler:
|
|||||||
self.seed = seed
|
self.seed = seed
|
||||||
self._epoch = 0
|
self._epoch = 0
|
||||||
self._start_index = 0
|
self._start_index = 0
|
||||||
|
self._absolute_to_relative = absolute_to_relative_idx
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def indices(self) -> list[int]:
|
def indices(self) -> list[int]:
|
||||||
@@ -132,7 +134,10 @@ class EpisodeAwareSampler:
|
|||||||
def _frame_index(self, position: int) -> int:
|
def _frame_index(self, position: int) -> int:
|
||||||
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
|
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
|
||||||
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
|
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
|
||||||
return int(self._starts[episode]) + position_in_episode
|
absolute_idx = int(self._starts[episode]) + position_in_episode
|
||||||
|
if self._absolute_to_relative is not None:
|
||||||
|
return self._absolute_to_relative[absolute_idx]
|
||||||
|
return absolute_idx
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[int]:
|
def __iter__(self) -> Iterator[int]:
|
||||||
# Advance epoch state eagerly, not on first consumption of the generator.
|
# Advance epoch state eagerly, not on first consumption of the generator.
|
||||||
|
|||||||
@@ -45,7 +45,8 @@ from lerobot.common.train_utils import (
|
|||||||
from lerobot.common.wandb_utils import WandBLogger
|
from lerobot.common.wandb_utils import WandBLogger
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
|
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state
|
||||||
|
from lerobot.datasets.factory import make_train_eval_datasets
|
||||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||||
@@ -244,19 +245,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
|
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
logging.info("Creating dataset")
|
logging.info("Creating dataset")
|
||||||
dataset = make_dataset(cfg)
|
dataset, eval_dataset = make_train_eval_datasets(cfg)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# Other ranks read from the shared copy populated by the main process.
|
# Other ranks read from the shared copy populated by the main process.
|
||||||
if not is_main_process:
|
if not is_main_process:
|
||||||
dataset = make_dataset(cfg)
|
dataset, eval_dataset = make_train_eval_datasets(cfg)
|
||||||
|
|
||||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||||
eval_env = None
|
eval_env = None
|
||||||
if cfg.eval_freq > 0 and cfg.env is not None and is_main_process:
|
if cfg.env_eval_freq > 0 and cfg.env is not None and is_main_process:
|
||||||
logging.info("Creating env")
|
logging.info("Creating env")
|
||||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||||
|
|
||||||
@@ -406,6 +407,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
|
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
seed=cfg.seed if cfg.seed is not None else 0,
|
seed=cfg.seed if cfg.seed is not None else 0,
|
||||||
|
absolute_to_relative_idx=dataset.absolute_to_relative_idx,
|
||||||
)
|
)
|
||||||
if cfg.resume and step > 0:
|
if cfg.resume and step > 0:
|
||||||
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so
|
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so
|
||||||
@@ -455,6 +457,33 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Build eval dataloader if a held-out split exists
|
||||||
|
eval_dataloader = None
|
||||||
|
if eval_dataset is not None:
|
||||||
|
eval_ds = eval_dataset
|
||||||
|
if cfg.max_eval_samples > 0 and hasattr(eval_dataset, "hf_dataset"):
|
||||||
|
task_arr = eval_dataset.hf_dataset.data.column("task_index").to_numpy()
|
||||||
|
unique_tasks = sorted(set(task_arr.tolist()))
|
||||||
|
per_task = max(1, cfg.max_eval_samples // len(unique_tasks))
|
||||||
|
selected: list[int] = []
|
||||||
|
for t in unique_tasks:
|
||||||
|
frames = (task_arr == t).nonzero()[0][:per_task]
|
||||||
|
selected.extend(frames.tolist())
|
||||||
|
eval_ds = torch.utils.data.Subset(eval_dataset, selected)
|
||||||
|
|
||||||
|
eval_collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
|
||||||
|
eval_dataloader = torch.utils.data.DataLoader(
|
||||||
|
eval_ds,
|
||||||
|
batch_size=cfg.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=cfg.num_workers,
|
||||||
|
pin_memory=device.type == "cuda",
|
||||||
|
drop_last=False,
|
||||||
|
collate_fn=eval_collate_fn,
|
||||||
|
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
|
||||||
|
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare everything with accelerator
|
# Prepare everything with accelerator
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||||
@@ -534,7 +563,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
train_tracker.step()
|
train_tracker.step()
|
||||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
||||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
is_env_eval_step = cfg.env_eval_freq > 0 and step % cfg.env_eval_freq == 0
|
||||||
|
is_eval_step = cfg.eval_steps > 0 and eval_dataloader is not None and step % cfg.eval_steps == 0
|
||||||
|
|
||||||
if is_log_step:
|
if is_log_step:
|
||||||
# Collective reduce must run on every rank, before the main-process gate below.
|
# Collective reduce must run on every rank, before the main-process gate below.
|
||||||
@@ -557,6 +587,27 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
wandb_logger.log_dict(wandb_log_dict, step)
|
wandb_logger.log_dict(wandb_log_dict, step)
|
||||||
train_tracker.reset_averages()
|
train_tracker.reset_averages()
|
||||||
|
|
||||||
|
if is_eval_step:
|
||||||
|
policy.eval()
|
||||||
|
eval_loss_sum = 0.0
|
||||||
|
n_eval_batches = 0
|
||||||
|
with torch.no_grad(), accelerator.autocast():
|
||||||
|
for eval_batch in eval_dataloader:
|
||||||
|
for cam_key in dataset.meta.camera_keys:
|
||||||
|
if cam_key in eval_batch and eval_batch[cam_key].dtype == torch.uint8:
|
||||||
|
eval_batch[cam_key] = eval_batch[cam_key].to(dtype=torch.float32) / 255.0
|
||||||
|
eval_batch = preprocessor(eval_batch)
|
||||||
|
loss, _ = policy.forward(eval_batch)
|
||||||
|
eval_loss_sum += loss.item()
|
||||||
|
n_eval_batches += 1
|
||||||
|
eval_loss = eval_loss_sum / max(n_eval_batches, 1)
|
||||||
|
policy.train()
|
||||||
|
|
||||||
|
if is_main_process:
|
||||||
|
logging.info(f"step {step}: eval_loss={eval_loss:.4f}")
|
||||||
|
if wandb_logger:
|
||||||
|
wandb_logger.log_dict({"eval_loss": eval_loss}, step=step, mode="eval")
|
||||||
|
|
||||||
if cfg.save_checkpoint and is_saving_step:
|
if cfg.save_checkpoint and is_saving_step:
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
logging.info(f"Checkpoint policy after step {step}")
|
logging.info(f"Checkpoint policy after step {step}")
|
||||||
@@ -579,7 +630,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
if cfg.env and is_eval_step:
|
if cfg.env and is_env_eval_step:
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
step_id = get_step_identifier(step, cfg.steps)
|
step_id = get_step_identifier(step, cfg.steps)
|
||||||
logging.info(f"Eval policy at step {step}")
|
logging.info(f"Eval policy at step {step}")
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ class TestMultiGPUTraining:
|
|||||||
f"--output_dir={output_dir}",
|
f"--output_dir={output_dir}",
|
||||||
"--batch_size=4",
|
"--batch_size=4",
|
||||||
"--steps=10",
|
"--steps=10",
|
||||||
"--eval_freq=-1",
|
"--env_eval_freq=-1",
|
||||||
"--log_freq=5",
|
"--log_freq=5",
|
||||||
"--save_freq=10",
|
"--save_freq=10",
|
||||||
"--seed=42",
|
"--seed=42",
|
||||||
@@ -177,7 +177,7 @@ class TestMultiGPUTraining:
|
|||||||
f"--output_dir={output_dir}",
|
f"--output_dir={output_dir}",
|
||||||
"--batch_size=4",
|
"--batch_size=4",
|
||||||
"--steps=20",
|
"--steps=20",
|
||||||
"--eval_freq=-1",
|
"--env_eval_freq=-1",
|
||||||
"--log_freq=5",
|
"--log_freq=5",
|
||||||
"--save_freq=10",
|
"--save_freq=10",
|
||||||
"--seed=42",
|
"--seed=42",
|
||||||
|
|||||||
Reference in New Issue
Block a user