mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-12 05:59:53 +00:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fcd8ab5800 | |||
| ee6eb745b8 | |||
| 27b482adf7 | |||
| 21d158e066 | |||
| 22991ed69a | |||
| 1adc7a0309 | |||
| f72fc3b4ba | |||
| dabf88ef9f | |||
| 2c47217825 | |||
| 9c502e204e | |||
| c55df19e6c | |||
| c91f345092 |
+1
-1
@@ -124,7 +124,7 @@ hardware = [
|
||||
"lerobot[deepdiff-dep]",
|
||||
]
|
||||
viz = [
|
||||
"rerun-sdk>=0.24.0,<0.27.0",
|
||||
"rerun-sdk>=0.24.0,<0.34.0",
|
||||
]
|
||||
# ── User-facing composite extras (map to CLI scripts) ─────
|
||||
# lerobot-record, lerobot-replay, lerobot-calibrate, lerobot-teleoperate, etc.
|
||||
|
||||
@@ -49,19 +49,8 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa
|
||||
return output_dir / CHECKPOINTS_DIR / step_identifier
|
||||
|
||||
|
||||
def save_training_step(
|
||||
step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None
|
||||
) -> None:
|
||||
state: dict = {"step": step}
|
||||
# num_processes and batch_size are recorded so a resumed run can detect a changed world size or
|
||||
# batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that
|
||||
# produced `step`, since both scale how many sampler positions a step consumes (see
|
||||
# compute_sampler_state).
|
||||
if num_processes is not None:
|
||||
state["num_processes"] = num_processes
|
||||
if batch_size is not None:
|
||||
state["batch_size"] = batch_size
|
||||
write_json(state, save_dir / TRAINING_STEP)
|
||||
def save_training_step(step: int, save_dir: Path) -> None:
|
||||
write_json({"step": step}, save_dir / TRAINING_STEP)
|
||||
|
||||
|
||||
def load_training_step(save_dir: Path) -> int:
|
||||
@@ -69,16 +58,6 @@ def load_training_step(save_dir: Path) -> int:
|
||||
return training_step["step"]
|
||||
|
||||
|
||||
def load_training_num_processes(checkpoint_dir: Path) -> int | None:
|
||||
"""World size recorded at checkpoint time, or None for checkpoints written before it was stored."""
|
||||
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes")
|
||||
|
||||
|
||||
def load_training_batch_size(checkpoint_dir: Path) -> int | None:
|
||||
"""Per-process batch size recorded at checkpoint time, or None for older checkpoints."""
|
||||
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size")
|
||||
|
||||
|
||||
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
|
||||
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
|
||||
if last_checkpoint_dir.is_symlink():
|
||||
@@ -96,8 +75,6 @@ def save_checkpoint(
|
||||
scheduler: LRScheduler | None = None,
|
||||
preprocessor: PolicyProcessorPipeline | None = None,
|
||||
postprocessor: PolicyProcessorPipeline | None = None,
|
||||
num_processes: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
@@ -123,10 +100,6 @@ def save_checkpoint(
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||
num_processes (int | None, optional): Distributed world size to record for sample-exact
|
||||
resume. Defaults to None (not recorded).
|
||||
batch_size (int | None, optional): Per-process batch size to record for sample-exact
|
||||
resume. Defaults to None (not recorded).
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
@@ -139,9 +112,7 @@ def save_checkpoint(
|
||||
preprocessor.save_pretrained(pretrained_dir)
|
||||
if postprocessor is not None:
|
||||
postprocessor.save_pretrained(pretrained_dir)
|
||||
save_training_state(
|
||||
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
|
||||
)
|
||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
||||
|
||||
|
||||
def save_training_state(
|
||||
@@ -149,8 +120,6 @@ def save_training_state(
|
||||
train_step: int,
|
||||
optimizer: Optimizer | None = None,
|
||||
scheduler: LRScheduler | None = None,
|
||||
num_processes: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Saves the training step, optimizer state, scheduler state, and rng state.
|
||||
@@ -162,12 +131,10 @@ def save_training_state(
|
||||
Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
|
||||
Defaults to None.
|
||||
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
|
||||
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
|
||||
"""
|
||||
save_dir = checkpoint_dir / TRAINING_STATE_DIR
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
|
||||
save_training_step(train_step, save_dir)
|
||||
save_rng_state(save_dir)
|
||||
if optimizer is not None:
|
||||
save_optimizer_state(optimizer, save_dir)
|
||||
|
||||
@@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
||||
from .sampler import EpisodeAwareSampler, compute_sampler_state
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import VideoEncodingManager
|
||||
@@ -82,7 +82,6 @@ __all__ = [
|
||||
"aggregate_stats",
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"compute_sampler_state",
|
||||
"create_lerobot_dataset_card",
|
||||
"column_for_style",
|
||||
"delete_episodes",
|
||||
|
||||
+32
-122
@@ -14,36 +14,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Iterator
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodeAwareSampler:
|
||||
"""Sampler over episode frames that stores only per-episode boundaries.
|
||||
|
||||
Logical positions map to frame indices on the fly (O(num_episodes) construction memory)
|
||||
instead of materializing a Python list of every frame index.
|
||||
|
||||
Each epoch is shuffled with a `torch.randperm` seeded from `(seed, epoch)`, so the data order
|
||||
is a pure function of `(seed, epoch)`: it reproduces on every rank without synchronizing the
|
||||
global RNG (no `generator` to sync across distributed ranks), and `state_dict` /
|
||||
`load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and
|
||||
continuing from the saved offset. Each call to `__iter__` advances the epoch. During a
|
||||
resumed epoch, `__len__` still reports the full length.
|
||||
|
||||
Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict`
|
||||
set it explicitly. Within a single run callers should rely on exactly one of these mechanisms,
|
||||
not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same
|
||||
iterations would skip or repeat epochs. The training loop drives it purely through `__iter__`
|
||||
(via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration
|
||||
starts (e.g. on resume or in tests).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_from_indices: list[int],
|
||||
@@ -52,125 +30,57 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
seed: int = 0,
|
||||
):
|
||||
"""
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
Args:
|
||||
dataset_from_indices: Start index of each episode in the dataset.
|
||||
dataset_to_indices: End index of each episode in the dataset.
|
||||
episode_indices_to_use: Episode indices to use; None means all.
|
||||
drop_n_first_frames: Frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Frames to drop from the end of each episode.
|
||||
dataset_from_indices: List of indices containing the start of each episode in the dataset.
|
||||
dataset_to_indices: List of indices containing the end of each episode in the dataset.
|
||||
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
||||
Assumes that episodes are indexed from 0 to N-1.
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
seed: Seed the permutation is derived from (together with the epoch).
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
if drop_n_last_frames < 0:
|
||||
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
|
||||
|
||||
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
|
||||
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
|
||||
if from_indices.shape != to_indices.shape:
|
||||
raise ValueError(
|
||||
f"dataset_from_indices and dataset_to_indices must have the same length, "
|
||||
f"got {len(from_indices)} and {len(to_indices)}"
|
||||
)
|
||||
indices = []
|
||||
for episode_idx, (start_index, end_index) in enumerate(
|
||||
zip(dataset_from_indices, dataset_to_indices, strict=True)
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
ep_length = end_index - start_index
|
||||
if drop_n_first_frames + drop_n_last_frames >= ep_length:
|
||||
logger.warning(
|
||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||
episode_idx,
|
||||
ep_length,
|
||||
drop_n_first_frames,
|
||||
drop_n_last_frames,
|
||||
)
|
||||
continue
|
||||
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
|
||||
|
||||
used = np.ones(len(from_indices), dtype=bool)
|
||||
if episode_indices_to_use is not None:
|
||||
used = np.zeros(len(from_indices), dtype=bool)
|
||||
used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True
|
||||
|
||||
starts = from_indices + drop_n_first_frames
|
||||
lengths = to_indices - drop_n_last_frames - starts
|
||||
for episode_idx in np.flatnonzero(used & (lengths <= 0)):
|
||||
logger.warning(
|
||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||
episode_idx,
|
||||
to_indices[episode_idx] - from_indices[episode_idx],
|
||||
drop_n_first_frames,
|
||||
drop_n_last_frames,
|
||||
)
|
||||
used &= lengths > 0
|
||||
if not used.any():
|
||||
if not indices:
|
||||
raise ValueError(
|
||||
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
|
||||
"All episodes were either filtered out or had too few frames."
|
||||
)
|
||||
|
||||
self._starts = starts[used]
|
||||
self._cum_lengths = np.cumsum(lengths[used])
|
||||
self._num_frames = int(self._cum_lengths[-1])
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
self._epoch = 0
|
||||
self._start_index = 0
|
||||
|
||||
@property
|
||||
def indices(self) -> list[int]:
|
||||
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
|
||||
return [self._frame_index(k) for k in range(self._num_frames)]
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
self._epoch = epoch
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
return {"epoch": self._epoch, "start_index": self._start_index}
|
||||
|
||||
def load_state_dict(self, state: dict) -> None:
|
||||
self._epoch = state["epoch"]
|
||||
self._start_index = state["start_index"]
|
||||
|
||||
def _epoch_generator(self, epoch: int) -> torch.Generator:
|
||||
# Derive a per-epoch seed from (seed, epoch) so the permutation is a pure function of both
|
||||
# and reproduces identically on every rank without touching the global RNG.
|
||||
epoch_seed = int(np.random.SeedSequence([self.seed, epoch]).generate_state(1, dtype=np.uint64)[0])
|
||||
return torch.Generator().manual_seed(epoch_seed)
|
||||
|
||||
def _frame_index(self, position: int) -> int:
|
||||
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
|
||||
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
|
||||
return int(self._starts[episode]) + position_in_episode
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
# Advance epoch state eagerly, not on first consumption of the generator.
|
||||
epoch, start = self._epoch, self._start_index
|
||||
self._epoch += 1
|
||||
self._start_index = 0
|
||||
return self._iter_epoch(epoch, start)
|
||||
|
||||
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
order = torch.randperm(self._num_frames, generator=self._epoch_generator(epoch))
|
||||
for k in range(start, self._num_frames):
|
||||
yield self._frame_index(int(order[k]))
|
||||
for i in torch.randperm(len(self.indices)):
|
||||
yield self.indices[i]
|
||||
else:
|
||||
for k in range(start, self._num_frames):
|
||||
yield self._frame_index(k)
|
||||
for i in self.indices:
|
||||
yield i
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._num_frames
|
||||
|
||||
|
||||
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
|
||||
"""Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume.
|
||||
|
||||
Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler
|
||||
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
|
||||
per epoch (`even_batches` padding included). The start index provably stays below
|
||||
`num_frames`; the `min` is defensive.
|
||||
|
||||
Assumptions (resume is only sample-exact when they hold):
|
||||
- `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how
|
||||
many positions a step consumes, so the epoch/offset are wrong if either changed. The
|
||||
caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch.
|
||||
- accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term
|
||||
mirrors that padding; with `even_batches=False` the per-epoch batch count differs and
|
||||
the boundary is off.
|
||||
"""
|
||||
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
|
||||
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
|
||||
start_index = min(batches_into_epoch * batch_size * num_processes, num_frames)
|
||||
return {"epoch": epoch, "start_index": start_index}
|
||||
return len(self.indices)
|
||||
|
||||
@@ -77,6 +77,21 @@ from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
def get_feature_names(dataset: LeRobotDataset, key: str) -> list[str]:
|
||||
"""Return per-dimension names for a feature from the dataset metadata.
|
||||
|
||||
Only flat-list ``names`` metadata is used. Dict-style ``names`` and missing names fall back to ``{key}_{i}`` indices.
|
||||
"""
|
||||
feature = dataset.features[key]
|
||||
dim = feature["shape"][-1]
|
||||
|
||||
names = feature.get("names")
|
||||
if isinstance(names, list) and len(names) == dim:
|
||||
return [str(name) for name in names]
|
||||
|
||||
return [f"{key}_{d}" for d in range(dim)]
|
||||
|
||||
|
||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
assert chw_float32_torch.dtype == torch.float32
|
||||
assert chw_float32_torch.ndim == 3
|
||||
@@ -86,6 +101,31 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
return hwc_uint8_numpy
|
||||
|
||||
|
||||
def build_blueprint_from_dataset(dataset: LeRobotDataset):
|
||||
"""Build a Rerun blueprint laying out camera images and time series for the given dataset.
|
||||
|
||||
Camera images and scalar signals (action, state, reward, done, success) are arranged in a grid.
|
||||
The per-dimension series names for ``action`` and ``state`` are applied directly
|
||||
via blueprint overrides.
|
||||
"""
|
||||
import rerun as rr
|
||||
import rerun.blueprint as rrb
|
||||
|
||||
views = [rrb.Spatial2DView(origin=key, name=key) for key in dataset.meta.camera_keys]
|
||||
|
||||
# Style multi-dimensional signals (action, state) with per-dimension names.
|
||||
for origin, key in ((ACTION, ACTION), ("state", OBS_STATE)):
|
||||
if key in dataset.features:
|
||||
names = get_feature_names(dataset, key)
|
||||
styling = rr.SeriesLines(names=names)
|
||||
views.append(rrb.TimeSeriesView(origin=origin, name=origin, overrides={origin: styling}))
|
||||
for key in (DONE, REWARD, "next.success"):
|
||||
if key in dataset.features:
|
||||
views.append(rrb.TimeSeriesView(origin=key, name=key))
|
||||
|
||||
return rrb.Blueprint(rrb.Grid(*views))
|
||||
|
||||
|
||||
def visualize_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
episode_index: int,
|
||||
@@ -124,7 +164,8 @@ def visualize_dataset(
|
||||
import rerun as rr
|
||||
|
||||
spawn_local_viewer = mode == "local" and not save
|
||||
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
|
||||
blueprint = build_blueprint_from_dataset(dataset)
|
||||
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer, default_blueprint=blueprint)
|
||||
|
||||
# Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
|
||||
# when iterating on a dataloader with `num_workers` > 0
|
||||
@@ -142,26 +183,21 @@ def visualize_dataset(
|
||||
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||
if first_index is None:
|
||||
first_index = batch["index"][0].item()
|
||||
# iterate over the batch
|
||||
|
||||
for i in range(len(batch["index"])):
|
||||
rr.set_time("frame_index", sequence=batch["index"][i].item() - first_index)
|
||||
rr.set_time("timestamp", timestamp=batch["timestamp"][i].item())
|
||||
|
||||
# display each camera image
|
||||
for key in dataset.meta.camera_keys:
|
||||
img = to_hwc_uint8_numpy(batch[key][i])
|
||||
img_entity = rr.Image(img).compress() if display_compressed_images else rr.Image(img)
|
||||
rr.log(key, entity=img_entity)
|
||||
|
||||
# display each dimension of action space (e.g. actuators command)
|
||||
if ACTION in batch:
|
||||
for dim_idx, val in enumerate(batch[ACTION][i]):
|
||||
rr.log(f"{ACTION}/{dim_idx}", rr.Scalars(val.item()))
|
||||
rr.log(ACTION, rr.Scalars(batch[ACTION][i].numpy()))
|
||||
|
||||
# display each dimension of observed state space (e.g. agent position in joint space)
|
||||
if OBS_STATE in batch:
|
||||
for dim_idx, val in enumerate(batch[OBS_STATE][i]):
|
||||
rr.log(f"state/{dim_idx}", rr.Scalars(val.item()))
|
||||
rr.log("state", rr.Scalars(batch[OBS_STATE][i].numpy()))
|
||||
|
||||
if DONE in batch:
|
||||
rr.log(DONE, rr.Scalars(batch[DONE][i].item()))
|
||||
@@ -173,8 +209,6 @@ def visualize_dataset(
|
||||
rr.log("next.success", rr.Scalars(batch["next.success"][i].item()))
|
||||
|
||||
if mode == "local" and save:
|
||||
# save .rrd locally
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
repo_id_str = repo_id.replace("/", "_")
|
||||
rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
|
||||
@@ -182,7 +216,7 @@ def visualize_dataset(
|
||||
return rrd_path
|
||||
|
||||
elif mode == "distant":
|
||||
# stop the process from exiting since it is serving the websocket connection
|
||||
# Keep the process alive while it serves the gRPC/web connection.
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
@@ -297,12 +331,14 @@ def main():
|
||||
)
|
||||
logging.warning("Setting grpc_port to ws_port value.")
|
||||
kwargs["grpc_port"] = kwargs.pop("ws_port")
|
||||
else:
|
||||
kwargs.pop("ws_port") # Always remove ws_port from kwargs
|
||||
|
||||
init_logging()
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
|
||||
|
||||
visualize_dataset(dataset, **vars(args))
|
||||
visualize_dataset(dataset, **kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -36,8 +36,6 @@ from tqdm import tqdm
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_batch_size,
|
||||
load_training_num_processes,
|
||||
load_training_state,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
@@ -45,7 +43,7 @@ from lerobot.common.train_utils import (
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
|
||||
from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
@@ -234,16 +232,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: the global main process downloads once to the shared
|
||||
# dataset root, then a barrier lets every other rank read the already-populated copy.
|
||||
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Other ranks read from the shared copy populated by the main process.
|
||||
# Now all other processes can safely load the dataset
|
||||
if not is_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
@@ -388,47 +384,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if not cfg.dataset.streaming:
|
||||
# All non-streaming (map-style) datasets use EpisodeAwareSampler.
|
||||
# The order is a pure function of (seed, epoch), so every rank independently produces the
|
||||
# same permutation. accelerate then shards it disjointly across ranks via BatchSamplerShard
|
||||
# without needing a `generator` attribute to synchronize an RNG, and resume is sample-exact.
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
seed=cfg.seed if cfg.seed is not None else 0,
|
||||
)
|
||||
if cfg.resume and step > 0:
|
||||
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so
|
||||
# use the values recorded in the checkpoint (falling back to the current ones for older
|
||||
# ckpts that did not store them).
|
||||
saved_num_processes = load_training_num_processes(cfg.checkpoint_path)
|
||||
saved_batch_size = load_training_batch_size(cfg.checkpoint_path)
|
||||
ckpt_num_processes = saved_num_processes or accelerator.num_processes
|
||||
ckpt_batch_size = saved_batch_size or cfg.batch_size
|
||||
if is_main_process and saved_num_processes not in (None, accelerator.num_processes):
|
||||
logging.warning(
|
||||
f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was "
|
||||
f"written with num_processes={saved_num_processes}. The data order resumes at the "
|
||||
"right epoch/offset, but per-rank sample-exactness requires the same world size."
|
||||
)
|
||||
if is_main_process and saved_batch_size not in (None, cfg.batch_size):
|
||||
logging.warning(
|
||||
f"Resuming with batch_size={cfg.batch_size} but the checkpoint was written with "
|
||||
f"batch_size={saved_batch_size}. The data order resumes at the right epoch/offset, "
|
||||
"but per-rank sample-exactness requires the same batch size."
|
||||
)
|
||||
sampler_state = compute_sampler_state(step, len(sampler), ckpt_batch_size, ckpt_num_processes)
|
||||
sampler.load_state_dict(sampler_state)
|
||||
if is_main_process:
|
||||
logging.info(
|
||||
f"Resuming data order at epoch {sampler_state['epoch']}, "
|
||||
f"sample {sampler_state['start_index']}"
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
@@ -547,8 +511,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
scheduler=lr_scheduler,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
num_processes=accelerator.num_processes,
|
||||
batch_size=cfg.batch_size,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
|
||||
@@ -38,6 +38,8 @@ def init_rerun(
|
||||
require_package("rerun-sdk", extra="viz", import_name="rerun")
|
||||
import rerun as rr
|
||||
|
||||
log_rerun_data.blueprint = None # Reset blueprint cache for new session
|
||||
|
||||
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
|
||||
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
|
||||
rr.init(session_name)
|
||||
@@ -63,6 +65,38 @@ def _is_scalar(x):
|
||||
)
|
||||
|
||||
|
||||
def _build_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]):
|
||||
"""Build a Rerun blueprint laying out camera images, observation and action scalars in separate views.
|
||||
|
||||
Camera images, observation and action scalars are arranged in a grid.
|
||||
"""
|
||||
|
||||
# Safe + zero-overhead: `log_rerun_data` already ran the `require_package` guard and imported rerun.
|
||||
import rerun.blueprint as rrb
|
||||
|
||||
views = [rrb.Spatial2DView(origin=path, name=path) for path in sorted(image_paths)]
|
||||
|
||||
if observation_paths:
|
||||
views.append(rrb.TimeSeriesView(name="observation", contents=sorted(observation_paths)))
|
||||
if action_paths:
|
||||
views.append(rrb.TimeSeriesView(name="action", contents=sorted(action_paths)))
|
||||
|
||||
return rrb.Blueprint(rrb.Grid(*views))
|
||||
|
||||
|
||||
def _ensure_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]) -> None:
|
||||
"""Build and send the blueprint once, from the first observation and action data."""
|
||||
if getattr(log_rerun_data, "blueprint", None) is not None:
|
||||
return
|
||||
|
||||
# Safe + zero-overhead: `log_rerun_data` already ran the `require_package` guard and imported rerun.
|
||||
import rerun as rr
|
||||
|
||||
blueprint = _build_blueprint(observation_paths, action_paths, image_paths)
|
||||
log_rerun_data.blueprint = blueprint
|
||||
rr.send_blueprint(blueprint)
|
||||
|
||||
|
||||
def log_rerun_data(
|
||||
observation: RobotObservation | None = None,
|
||||
action: RobotAction | None = None,
|
||||
@@ -76,11 +110,15 @@ def log_rerun_data(
|
||||
- Scalars values (floats, ints) are logged as `rr.Scalars`.
|
||||
- 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed
|
||||
from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`.
|
||||
- 1D NumPy arrays are logged as a series of individual scalars, with each element indexed.
|
||||
- Other multi-dimensional arrays are flattened and logged as individual scalars.
|
||||
- 1D NumPy arrays are logged as a single `rr.Scalars` batch under one entity path, so that every
|
||||
dimension shares the same view instead of being split across one view per element.
|
||||
- Multi-dimensional **action** arrays are flattened and logged as a single `rr.Scalars` batch.
|
||||
|
||||
Keys are automatically namespaced with "observation." or "action." if not already present.
|
||||
|
||||
On the first call, a blueprint is built and sent so observation and action scalars get separate
|
||||
time-series views and each image gets its own spatial view.
|
||||
|
||||
Args:
|
||||
observation: An optional dictionary containing observation data to log.
|
||||
action: An optional dictionary containing action data to log.
|
||||
@@ -90,6 +128,10 @@ def log_rerun_data(
|
||||
require_package("rerun-sdk", extra="viz", import_name="rerun")
|
||||
import rerun as rr
|
||||
|
||||
observation_paths: set[str] = set()
|
||||
action_paths: set[str] = set()
|
||||
image_paths: set[str] = set()
|
||||
|
||||
if observation:
|
||||
for k, v in observation.items():
|
||||
if v is None:
|
||||
@@ -98,17 +140,19 @@ def log_rerun_data(
|
||||
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalars(float(v)))
|
||||
observation_paths.add(key)
|
||||
elif isinstance(v, np.ndarray):
|
||||
arr = v
|
||||
# Convert CHW -> HWC when needed
|
||||
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
|
||||
arr = np.transpose(arr, (1, 2, 0))
|
||||
if arr.ndim == 1:
|
||||
for i, vi in enumerate(arr):
|
||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
rr.log(key, rr.Scalars(arr.astype(float)))
|
||||
observation_paths.add(key)
|
||||
else:
|
||||
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
|
||||
rr.log(key, entity=img_entity, static=True)
|
||||
image_paths.add(key)
|
||||
|
||||
if action:
|
||||
for k, v in action.items():
|
||||
@@ -118,12 +162,9 @@ def log_rerun_data(
|
||||
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalars(float(v)))
|
||||
action_paths.add(key)
|
||||
elif isinstance(v, np.ndarray):
|
||||
if v.ndim == 1:
|
||||
for i, vi in enumerate(v):
|
||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
else:
|
||||
# Fall back to flattening higher-dimensional arrays
|
||||
flat = v.flatten()
|
||||
for i, vi in enumerate(flat):
|
||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
rr.log(key, rr.Scalars(v.reshape(-1).astype(float)))
|
||||
action_paths.add(key)
|
||||
|
||||
_ensure_blueprint(observation_paths, action_paths, image_paths)
|
||||
|
||||
@@ -114,19 +114,6 @@ def test_shuffle():
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_shuffle_is_reproducible_across_instances():
|
||||
# The order is a pure function of (seed, epoch), so two fresh samplers (e.g. two ranks)
|
||||
# produce the same permutation without any generator synchronization.
|
||||
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
|
||||
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
|
||||
epoch_0 = list(sampler_a)
|
||||
assert list(sampler_b) == epoch_0
|
||||
# Desyncing the global RNG must not affect the permutation.
|
||||
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
|
||||
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
|
||||
assert list(sampler_c) == epoch_0
|
||||
|
||||
|
||||
def test_negative_drop_first_frames_raises():
|
||||
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
|
||||
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
|
||||
@@ -150,87 +137,3 @@ def test_partial_episode_drop_warns(caplog):
|
||||
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
|
||||
assert sampler.indices == [2, 3, 4, 5]
|
||||
assert "Episode 0" in caplog.text
|
||||
|
||||
|
||||
# --- seeded (seed, epoch) shuffling, resume, and state ---
|
||||
|
||||
from lerobot.datasets.sampler import compute_sampler_state # noqa: E402
|
||||
|
||||
EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100])
|
||||
def test_deterministic_sampler_shuffle_is_permutation(num_frames):
|
||||
for seed in (0, 1, 1234):
|
||||
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed)
|
||||
assert sorted(sampler) == list(range(num_frames))
|
||||
|
||||
|
||||
def test_deterministic_sampler_epochs_reproduce_and_differ():
|
||||
sampler_a = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
|
||||
sampler_b = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
|
||||
epoch_0 = list(sampler_a)
|
||||
assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process
|
||||
epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch
|
||||
assert epoch_1 != epoch_0
|
||||
assert sorted(epoch_1) == sorted(epoch_0)
|
||||
sampler_a.set_epoch(0)
|
||||
assert list(sampler_a) == epoch_0
|
||||
assert list(EpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0
|
||||
|
||||
|
||||
def test_deterministic_sampler_resume_mid_epoch():
|
||||
reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
|
||||
epoch_0 = list(reference)
|
||||
epoch_1 = list(reference)
|
||||
for start in (0, 1, 4, len(epoch_0)):
|
||||
resumed = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
|
||||
resumed.load_state_dict({"epoch": 0, "start_index": start})
|
||||
assert list(resumed) == epoch_0[start:]
|
||||
# the resumed sampler continues into the same epoch 1 as the uninterrupted one
|
||||
assert list(resumed) == epoch_1
|
||||
|
||||
|
||||
def test_deterministic_sampler_construction_stores_only_boundaries():
|
||||
# Construction is O(num_episodes), not O(num_frames): a million-frame single episode
|
||||
# instantiates from just its boundaries without materializing a per-frame index list.
|
||||
num_frames = 1_000_000
|
||||
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||
assert len(sampler) == num_frames
|
||||
assert sampler._starts.shape == (1,) and sampler._cum_lengths.shape == (1,)
|
||||
|
||||
|
||||
def test_deterministic_sampler_resume_is_exact_at_scale():
|
||||
# Seeded randperm makes resume sample-exact at non-trivial sizes: regenerating the epoch's
|
||||
# permutation and slicing from the saved offset reproduces the remaining order exactly.
|
||||
num_frames = 100_000
|
||||
reference = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||
epoch_0 = list(reference)
|
||||
assert sorted(epoch_0) == list(range(num_frames))
|
||||
start = num_frames - 5
|
||||
resumed = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||
resumed.load_state_dict({"epoch": 0, "start_index": start})
|
||||
assert list(resumed) == epoch_0[start:]
|
||||
|
||||
|
||||
def test_compute_sampler_state():
|
||||
# 100 frames, batch 10, 2 ranks -> 10 underlying batches, 5 per rank per epoch.
|
||||
assert compute_sampler_state(step=0, num_frames=100, batch_size=10, num_processes=2) == {
|
||||
"epoch": 0,
|
||||
"start_index": 0,
|
||||
}
|
||||
# step 7 -> epoch 1, 2 per-rank batches in = 2 * 10 * 2 = 40 samples in
|
||||
assert compute_sampler_state(step=7, num_frames=100, batch_size=10, num_processes=2) == {
|
||||
"epoch": 1,
|
||||
"start_index": 40,
|
||||
}
|
||||
# uneven epoch: 95 frames -> 10 underlying batches (last short), still 5 per rank
|
||||
assert compute_sampler_state(step=12, num_frames=95, batch_size=10, num_processes=2) == {
|
||||
"epoch": 2,
|
||||
"start_index": 40,
|
||||
}
|
||||
# uneven sharding: 105 frames -> 11 underlying batches, 6 per rank (even_batches pads)
|
||||
assert compute_sampler_state(step=11, num_frames=105, batch_size=10, num_processes=2) == {
|
||||
"epoch": 1,
|
||||
"start_index": 100,
|
||||
}
|
||||
|
||||
@@ -20,8 +20,6 @@ from unittest.mock import Mock, patch
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_batch_size,
|
||||
load_training_num_processes,
|
||||
load_training_state,
|
||||
load_training_step,
|
||||
save_checkpoint,
|
||||
@@ -65,28 +63,6 @@ def test_load_training_step(tmp_path):
|
||||
assert loaded_step == step
|
||||
|
||||
|
||||
def test_save_training_state_records_num_processes(tmp_path, optimizer, scheduler):
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler, num_processes=4)
|
||||
assert load_training_num_processes(tmp_path) == 4
|
||||
|
||||
|
||||
def test_load_training_num_processes_absent_returns_none(tmp_path, optimizer, scheduler):
|
||||
# Checkpoints written before the world size was recorded must still load (back-compat).
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler)
|
||||
assert load_training_num_processes(tmp_path) is None
|
||||
|
||||
|
||||
def test_save_training_state_records_batch_size(tmp_path, optimizer, scheduler):
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler, batch_size=32)
|
||||
assert load_training_batch_size(tmp_path) == 32
|
||||
|
||||
|
||||
def test_load_training_batch_size_absent_returns_none(tmp_path, optimizer, scheduler):
|
||||
# Checkpoints written before the batch size was recorded must still load (back-compat).
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler)
|
||||
assert load_training_batch_size(tmp_path) is None
|
||||
|
||||
|
||||
def test_update_last_checkpoint(tmp_path):
|
||||
checkpoint = tmp_path / "0005"
|
||||
checkpoint.mkdir()
|
||||
|
||||
@@ -30,25 +30,46 @@ from lerobot.utils.constants import OBS_STATE
|
||||
@pytest.fixture
|
||||
def mock_rerun(monkeypatch):
|
||||
"""
|
||||
Provide a mock `rerun` module so tests don't depend on the real library.
|
||||
Also reload the module-under-test so it binds to this mock `rr`.
|
||||
Provide a mock `rerun` module (and `rerun.blueprint` submodule) so tests don't
|
||||
depend on the real library. Also reload the module-under-test so it binds to
|
||||
this mock `rr`.
|
||||
"""
|
||||
calls = []
|
||||
blueprints = []
|
||||
|
||||
class DummyScalar:
|
||||
def __init__(self, value):
|
||||
self.value = float(value)
|
||||
# Scalars may be built from a single float or from a 1D array batch.
|
||||
self.value = value
|
||||
|
||||
class DummyImage:
|
||||
def __init__(self, arr):
|
||||
self.arr = arr
|
||||
|
||||
def compress(self, *a, **k):
|
||||
return self
|
||||
|
||||
def dummy_log(key, obj=None, **kwargs):
|
||||
# Accept either positional `obj` or keyword `entity` and record remaining kwargs.
|
||||
if obj is None and "entity" in kwargs:
|
||||
obj = kwargs.pop("entity")
|
||||
calls.append((key, obj, kwargs))
|
||||
|
||||
def dummy_send_blueprint(blueprint, *a, **k):
|
||||
blueprints.append(blueprint)
|
||||
|
||||
# Mock the `rerun.blueprint` submodule used to build the layout.
|
||||
dummy_rrb = SimpleNamespace(
|
||||
Spatial2DView=lambda origin=None, name=None: SimpleNamespace(
|
||||
kind="Spatial2DView", origin=origin, name=name
|
||||
),
|
||||
TimeSeriesView=lambda name=None, contents=None: SimpleNamespace(
|
||||
kind="TimeSeriesView", name=name, contents=contents
|
||||
),
|
||||
Grid=lambda *views: SimpleNamespace(kind="Grid", views=list(views)),
|
||||
Blueprint=lambda root: SimpleNamespace(kind="Blueprint", root=root),
|
||||
)
|
||||
|
||||
dummy_rr = SimpleNamespace(
|
||||
__name__="rerun",
|
||||
__package__="rerun",
|
||||
@@ -56,20 +77,23 @@ def mock_rerun(monkeypatch):
|
||||
Scalars=DummyScalar,
|
||||
Image=DummyImage,
|
||||
log=dummy_log,
|
||||
send_blueprint=dummy_send_blueprint,
|
||||
init=lambda *a, **k: None,
|
||||
spawn=lambda *a, **k: None,
|
||||
blueprint=dummy_rrb,
|
||||
)
|
||||
|
||||
# Inject fake module into sys.modules
|
||||
# Inject fake modules into sys.modules (both `rerun` and `rerun.blueprint`).
|
||||
monkeypatch.setitem(sys.modules, "rerun", dummy_rr)
|
||||
monkeypatch.setitem(sys.modules, "rerun.blueprint", dummy_rrb)
|
||||
|
||||
# Now import and reload the module under test, to bind to our rerun mock
|
||||
import lerobot.utils.visualization_utils as vu
|
||||
|
||||
importlib.reload(vu)
|
||||
|
||||
# Expose both the reloaded module and the call recorder
|
||||
yield vu, calls
|
||||
# Expose the reloaded module, the call recorder and the captured blueprints
|
||||
yield vu, calls, blueprints
|
||||
|
||||
|
||||
def _keys(calls):
|
||||
@@ -92,8 +116,13 @@ def _kwargs_for(calls, key):
|
||||
raise KeyError(f"Key {key} not found in calls: {calls}")
|
||||
|
||||
|
||||
def _views_by_kind(blueprint, kind):
|
||||
"""Return the views of a given kind from the (single) blueprint's grid."""
|
||||
return [v for v in blueprint.root.views if v.kind == kind]
|
||||
|
||||
|
||||
def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
vu, calls = mock_rerun
|
||||
vu, calls, blueprints = mock_rerun
|
||||
|
||||
# Build EnvTransition dict
|
||||
obs = {
|
||||
@@ -103,7 +132,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
}
|
||||
act = {
|
||||
"action.throttle": 0.7,
|
||||
# 1D array should log individual Scalars with suffix _i
|
||||
# 1D array should be logged as a single Scalars batch under one entity path
|
||||
"action.vector": np.array([1.0, 2.0], dtype=np.float32),
|
||||
}
|
||||
transition = {
|
||||
@@ -120,31 +149,28 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
# - observation.state.temperature -> Scalars
|
||||
# - observation.camera -> Image (HWC) with static=True
|
||||
# - action.throttle -> Scalars
|
||||
# - action.vector_0, action.vector_1 -> Scalars
|
||||
# - action.vector -> single Scalars batch (no per-element suffix)
|
||||
expected_keys = {
|
||||
f"{OBS_STATE}.temperature",
|
||||
"observation.camera",
|
||||
"action.throttle",
|
||||
"action.vector_0",
|
||||
"action.vector_1",
|
||||
"action.vector",
|
||||
}
|
||||
assert set(_keys(calls)) == expected_keys
|
||||
|
||||
# Check scalar types and values
|
||||
temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature")
|
||||
assert type(temp_obj).__name__ == "DummyScalar"
|
||||
assert temp_obj.value == pytest.approx(25.0)
|
||||
assert float(temp_obj.value) == pytest.approx(25.0)
|
||||
|
||||
throttle_obj = _obj_for(calls, "action.throttle")
|
||||
assert type(throttle_obj).__name__ == "DummyScalar"
|
||||
assert throttle_obj.value == pytest.approx(0.7)
|
||||
assert float(throttle_obj.value) == pytest.approx(0.7)
|
||||
|
||||
v0 = _obj_for(calls, "action.vector_0")
|
||||
v1 = _obj_for(calls, "action.vector_1")
|
||||
assert type(v0).__name__ == "DummyScalar"
|
||||
assert type(v1).__name__ == "DummyScalar"
|
||||
assert v0.value == pytest.approx(1.0)
|
||||
assert v1.value == pytest.approx(2.0)
|
||||
# 1D vector logged as a single batched Scalars under one entity path
|
||||
vec = _obj_for(calls, "action.vector")
|
||||
assert type(vec).__name__ == "DummyScalar"
|
||||
np.testing.assert_allclose(np.asarray(vec.value), [1.0, 2.0])
|
||||
|
||||
# Check image handling: CHW -> HWC
|
||||
img_obj = _obj_for(calls, "observation.camera")
|
||||
@@ -152,9 +178,24 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
assert img_obj.arr.shape == (10, 20, 3) # transposed
|
||||
assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images
|
||||
|
||||
# A blueprint should have been built and sent exactly once, and cached on the function.
|
||||
assert len(blueprints) == 1
|
||||
assert vu.log_rerun_data.blueprint is blueprints[0]
|
||||
|
||||
bp = blueprints[0]
|
||||
# One spatial view per image path
|
||||
spatial_views = _views_by_kind(bp, "Spatial2DView")
|
||||
assert {v.origin for v in spatial_views} == {"observation.camera"}
|
||||
|
||||
# One time-series view each for observation and action scalars
|
||||
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
|
||||
assert set(ts_views) == {"observation", "action"}
|
||||
assert ts_views["observation"].contents == [f"{OBS_STATE}.temperature"]
|
||||
assert ts_views["action"].contents == ["action.throttle", "action.vector"]
|
||||
|
||||
|
||||
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
vu, calls = mock_rerun
|
||||
vu, calls, blueprints = mock_rerun
|
||||
|
||||
# First dict without prefixes treated as observation
|
||||
# Second dict without prefixes treated as action
|
||||
@@ -173,14 +214,12 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
# First dict was treated as observation, second as action
|
||||
vu.log_rerun_data(observation=obs_plain, action=act_plain)
|
||||
|
||||
# Expected keys with auto-prefixes
|
||||
# Expected keys with auto-prefixes. The 1D vector is a single batched Scalars.
|
||||
expected = {
|
||||
"observation.temp",
|
||||
"observation.img",
|
||||
"action.throttle",
|
||||
"action.vec_0",
|
||||
"action.vec_1",
|
||||
"action.vec_2",
|
||||
"action.vec",
|
||||
}
|
||||
logged = set(_keys(calls))
|
||||
assert logged == expected
|
||||
@@ -188,11 +227,11 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
# Scalars
|
||||
t = _obj_for(calls, "observation.temp")
|
||||
assert type(t).__name__ == "DummyScalar"
|
||||
assert t.value == pytest.approx(1.5)
|
||||
assert float(t.value) == pytest.approx(1.5)
|
||||
|
||||
throttle = _obj_for(calls, "action.throttle")
|
||||
assert type(throttle).__name__ == "DummyScalar"
|
||||
assert throttle.value == pytest.approx(0.3)
|
||||
assert float(throttle.value) == pytest.approx(0.3)
|
||||
|
||||
# Image stays HWC
|
||||
img = _obj_for(calls, "observation.img")
|
||||
@@ -200,15 +239,23 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
assert img.arr.shape == (5, 6, 3)
|
||||
assert _kwargs_for(calls, "observation.img").get("static", False) is True
|
||||
|
||||
# Vectors
|
||||
for i, val in enumerate([9, 8, 7]):
|
||||
o = _obj_for(calls, f"action.vec_{i}")
|
||||
assert type(o).__name__ == "DummyScalar"
|
||||
assert o.value == pytest.approx(val)
|
||||
# Vector logged as a single batched Scalars under one entity path
|
||||
vec = _obj_for(calls, "action.vec")
|
||||
assert type(vec).__name__ == "DummyScalar"
|
||||
np.testing.assert_allclose(np.asarray(vec.value), [9, 8, 7])
|
||||
|
||||
# Blueprint sent once with the expected view layout
|
||||
assert len(blueprints) == 1
|
||||
bp = blueprints[0]
|
||||
spatial_views = _views_by_kind(bp, "Spatial2DView")
|
||||
assert {v.origin for v in spatial_views} == {"observation.img"}
|
||||
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
|
||||
assert ts_views["observation"].contents == ["observation.temp"]
|
||||
assert ts_views["action"].contents == ["action.throttle", "action.vec"]
|
||||
|
||||
|
||||
def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
vu, calls = mock_rerun
|
||||
vu, calls, blueprints = mock_rerun
|
||||
|
||||
vu.log_rerun_data(
|
||||
observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
|
||||
@@ -222,7 +269,7 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
|
||||
temp = _obj_for(calls, "observation.temp")
|
||||
assert type(temp).__name__ == "DummyScalar"
|
||||
assert temp.value == pytest.approx(10.0)
|
||||
assert float(temp.value) == pytest.approx(10.0)
|
||||
|
||||
img = _obj_for(calls, "observation.gray")
|
||||
assert type(img).__name__ == "DummyImage"
|
||||
@@ -231,4 +278,26 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
|
||||
a = _obj_for(calls, "action.a")
|
||||
assert type(a).__name__ == "DummyScalar"
|
||||
assert a.value == pytest.approx(1.0)
|
||||
assert float(a.value) == pytest.approx(1.0)
|
||||
|
||||
# Blueprint sent once, with a spatial view for the image and time-series views for scalars
|
||||
assert len(blueprints) == 1
|
||||
bp = blueprints[0]
|
||||
assert {v.origin for v in _views_by_kind(bp, "Spatial2DView")} == {"observation.gray"}
|
||||
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
|
||||
assert ts_views["observation"].contents == ["observation.temp"]
|
||||
assert ts_views["action"].contents == ["action.a"]
|
||||
|
||||
|
||||
def test_log_rerun_data_blueprint_sent_only_once(mock_rerun):
|
||||
"""The blueprint is built from the first call and not resent on subsequent calls."""
|
||||
vu, calls, blueprints = mock_rerun
|
||||
|
||||
vu.log_rerun_data(observation={"temp": 1.0}, action={"a": 2.0})
|
||||
assert len(blueprints) == 1
|
||||
first_blueprint = vu.log_rerun_data.blueprint
|
||||
|
||||
vu.log_rerun_data(observation={"temp": 3.0}, action={"a": 4.0})
|
||||
# Still only one blueprint, and the cached one is unchanged.
|
||||
assert len(blueprints) == 1
|
||||
assert vu.log_rerun_data.blueprint is first_blueprint
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 3
|
||||
revision = 2
|
||||
requires-python = ">=3.12"
|
||||
resolution-markers = [
|
||||
"(python_full_version >= '3.15' and platform_machine == 'AMD64' and sys_platform == 'linux') or (python_full_version >= '3.15' and platform_machine == 'x86_64' and sys_platform == 'linux')",
|
||||
@@ -3257,7 +3257,7 @@ requires-dist = [
|
||||
{ name = "qwen-vl-utils", marker = "extra == 'qwen-vl-utils-dep'", specifier = ">=0.0.11,<0.1.0" },
|
||||
{ name = "reachy2-sdk", marker = "extra == 'reachy2'", specifier = ">=1.0.15,<1.1.0" },
|
||||
{ name = "requests", specifier = ">=2.32.0,<3.0.0" },
|
||||
{ name = "rerun-sdk", marker = "extra == 'viz'", specifier = ">=0.24.0,<0.27.0" },
|
||||
{ name = "rerun-sdk", marker = "extra == 'viz'", specifier = ">=0.24.0,<0.34.0" },
|
||||
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.14.1" },
|
||||
{ name = "safetensors", specifier = ">=0.4.3,<1.0.0" },
|
||||
{ name = "scikit-image", marker = "extra == 'video-benchmark'", specifier = ">=0.23.2,<0.26.0" },
|
||||
@@ -5636,21 +5636,21 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "rerun-sdk"
|
||||
version = "0.26.2"
|
||||
version = "0.33.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "attrs" },
|
||||
{ name = "numpy" },
|
||||
{ name = "pillow" },
|
||||
{ name = "psutil" },
|
||||
{ name = "pyarrow" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4b/4a/767c20e1529d74d9be5b5e55c6c26b63a6918ef3c1709fc422d08a460114/rerun_sdk-0.26.2-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3d4151c9a3484e112b53d1df90c8fa07397dc7b8bfbb420f09e011eff20f1ef2", size = 93349439, upload-time = "2025-10-27T11:34:10.745Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2b/3d/d8dd0af9c287a85d51ec99d69406cc4b94a9feb1d6f192d3bbcaac9f0b81/rerun_sdk-0.26.2-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:03977d2aba4966d9a70b682eca196123fda11408fecd733441ede9916c6341e2", size = 86323042, upload-time = "2025-10-27T11:34:17.995Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/13/29/53d8d98799ab32418fd4ba6834d6a5749c31f56160d3c87f52a7219887e9/rerun_sdk-0.26.2-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:b6128c3c4f014cae5be18e4d37657c5932d1bcdb2ce5e9d4b488a6eed47f7437", size = 92677274, upload-time = "2025-10-27T11:34:22.601Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f5/86/0b9c8f56398b4fc85f8e99279907c258413a297e5603f8f2537fe5806e51/rerun_sdk-0.26.2-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a6f97b60aaa7d4e8c6124a3f6b97ce9dbd09520050955f0e0bdacb72b0eb106a", size = 98768129, upload-time = "2025-10-27T11:34:27.36Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/be/e7/99fc91c0f99f69d7d43e1db0a6f6cb8273ffc02111539bfc1fee43749bad/rerun_sdk-0.26.2-cp39-abi3-win_amd64.whl", hash = "sha256:a493ad6c8357022cba2ca6f8954a81d0faf984b0b22154eb1d976bfc7649df63", size = 84267089, upload-time = "2025-10-27T11:34:32.023Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/17/5a521e86ac0064bd0f452e3e98e2422433511b54110423c0217d2cc1234f/rerun_sdk-0.33.0-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:97f123e3ef6aa69b60194bc566e5435c7d4040757ed4f58297ea46c8ef320c5c", size = 125707606, upload-time = "2026-05-29T09:42:53.584Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/34/2f/2ca2599aca03b69fbcac7c8391ef50376968edd7c58b96de53a4b7f20624/rerun_sdk-0.33.0-cp310-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:8f734cf59419dcfbc46915bea6cec030224f16e96c3a597f0ccf7cb7b058dd43", size = 135271020, upload-time = "2026-05-29T09:43:00.106Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2e/ba/d70997b43e6db4f58c4326c29c6a6a384ddc6c2fe125f231c885ad9b3b1f/rerun_sdk-0.33.0-cp310-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:53d95609f8b330026bcd041bf6d11b46ee1c18b6fbde155135f291fe86328eeb", size = 139552018, upload-time = "2026-05-29T09:43:06.275Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/a5/0cac294d16aff6c9a2f183f838428a0380b4d2fd9e053bb37b3041999ad5/rerun_sdk-0.33.0-cp310-abi3-win_amd64.whl", hash = "sha256:b152992a72ec240062c8c285bd30ab681b464a25efbe1464c66fdac82320de1f", size = 120418186, upload-time = "2026-05-29T09:43:13.733Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user