mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix multi gpu processor bug
This commit is contained in:
@@ -210,8 +210,21 @@ class DeltaActionsProcessorStep(ProcessorStep):
|
||||
def _build_mask(self, action_dim: int) -> list[bool]:
|
||||
if not self.exclude_joints or self.action_names is None:
|
||||
return [True] * action_dim
|
||||
exclude = set(self.exclude_joints)
|
||||
return [n not in exclude for n in self.action_names]
|
||||
|
||||
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
|
||||
if not exclude_tokens:
|
||||
return [True] * action_dim
|
||||
|
||||
mask = []
|
||||
for name in self.action_names[:action_dim]:
|
||||
action_name = str(name).lower()
|
||||
is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
|
||||
mask.append(not is_excluded)
|
||||
|
||||
if len(mask) < action_dim:
|
||||
mask.extend([True] * (action_dim - len(mask)))
|
||||
|
||||
return mask
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION, {})
|
||||
|
||||
@@ -211,6 +211,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
delta_action_stats = None
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
@@ -220,13 +221,27 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.compute_stats import get_feature_stats
|
||||
from lerobot.processor.delta_action_processor import to_delta_actions
|
||||
from lerobot.processor.delta_action_processor import DeltaActionsProcessorStep, to_delta_actions
|
||||
|
||||
chunk_size = cfg.policy.chunk_size
|
||||
hf = dataset.hf_dataset
|
||||
total_frames = len(hf)
|
||||
max_samples = min(500_000, total_frames - chunk_size)
|
||||
indices = np.random.choice(total_frames - chunk_size, max_samples, replace=False)
|
||||
sample_upper_bound = total_frames - chunk_size
|
||||
if sample_upper_bound <= 0:
|
||||
raise ValueError(
|
||||
f"Cannot compute delta action stats: total_frames={total_frames}, chunk_size={chunk_size}"
|
||||
)
|
||||
|
||||
max_samples = min(100_000, sample_upper_bound)
|
||||
indices = np.random.choice(sample_upper_bound, max_samples, replace=False)
|
||||
|
||||
action_names = dataset.meta.features.get("action", {}).get("names")
|
||||
delta_mask_step = DeltaActionsProcessorStep(
|
||||
enabled=True,
|
||||
exclude_joints=getattr(cfg.policy, "delta_exclude_joints", []),
|
||||
action_names=action_names,
|
||||
)
|
||||
delta_mask = delta_mask_step._build_mask(dataset.meta.features["action"]["shape"][0])
|
||||
logging.info(
|
||||
f"use_delta_actions is enabled — computing delta action stats "
|
||||
f"from {max_samples} chunk samples (chunk_size={chunk_size})"
|
||||
@@ -245,13 +260,16 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
actions = torch.tensor(np.stack([np.asarray(a) for a in chunk_data["action"]])).float()
|
||||
state = torch.tensor(np.asarray(chunk_data["observation.state"][0])).float()
|
||||
|
||||
mask = [True] * actions.shape[-1]
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), delta_mask).squeeze(0)
|
||||
all_delta_actions.append(delta.numpy())
|
||||
|
||||
if not all_delta_actions:
|
||||
raise RuntimeError("Failed to compute delta action stats: no valid chunks found.")
|
||||
|
||||
all_delta = np.concatenate(all_delta_actions, axis=0)
|
||||
delta_stats = get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
|
||||
dataset.meta.stats["action"] = delta_stats
|
||||
delta_action_stats = delta_stats
|
||||
dataset.meta.stats["action"] = delta_action_stats
|
||||
|
||||
norm_type = "UNKNOWN"
|
||||
if hasattr(cfg.policy, "normalization_mapping"):
|
||||
@@ -259,8 +277,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
action_norm = cfg.policy.normalization_mapping.get("ACTION", None)
|
||||
norm_type = action_norm.value if action_norm else "UNKNOWN"
|
||||
|
||||
excluded_dims = len(delta_mask) - sum(delta_mask)
|
||||
logging.info(
|
||||
f"Delta action stats ({len(all_delta_actions)} chunks, {len(all_delta)} values, norm={norm_type}): "
|
||||
f"delta_dims={sum(delta_mask)}/{len(delta_mask)} (excluded={excluded_dims}), "
|
||||
f"mean={np.abs(delta_stats['mean']).mean():.4f}, std={delta_stats['std'].mean():.4f}, "
|
||||
f"q01={delta_stats['q01'].mean():.4f}, q99={delta_stats['q99'].mean():.4f}"
|
||||
)
|
||||
@@ -274,6 +294,15 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
if not is_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Ensure all ranks use the exact same delta action stats.
|
||||
if getattr(cfg.policy, "use_delta_actions", False):
|
||||
if accelerator.num_processes > 1 and torch.distributed.is_initialized():
|
||||
stats_list = [delta_action_stats]
|
||||
torch.distributed.broadcast_object_list(stats_list, src=0)
|
||||
delta_action_stats = stats_list[0]
|
||||
if delta_action_stats is not None:
|
||||
dataset.meta.stats["action"] = delta_action_stats
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||
@@ -299,10 +328,22 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
processor_pretrained_path = cfg.policy.pretrained_path
|
||||
if (
|
||||
getattr(cfg.policy, "use_delta_actions", False)
|
||||
and processor_pretrained_path is not None
|
||||
and not cfg.resume
|
||||
):
|
||||
logging.warning(
|
||||
"use_delta_actions=true with pretrained processors can skip delta transforms if "
|
||||
"the checkpoint processors do not define them. Building processors from current policy config."
|
||||
)
|
||||
processor_pretrained_path = None
|
||||
|
||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
|
||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
@@ -310,7 +351,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
if cfg.policy.type == "sarm":
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if cfg.policy.pretrained_path is not None:
|
||||
if processor_pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
"device_processor": {"device": device.type},
|
||||
"normalizer_processor": {
|
||||
@@ -332,7 +373,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
@@ -450,7 +491,36 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
|
||||
# Debug logging for first few steps and periodically
|
||||
if is_main_process and (step < 3 or (cfg.log_freq > 0 and step % (cfg.log_freq * 10) == 0)):
|
||||
action = batch.get("action")
|
||||
state = batch.get("observation.state")
|
||||
if action is not None and state is not None:
|
||||
logging.info(
|
||||
f"[DEBUG step={step}] PRE-PROCESSOR — "
|
||||
f"action: shape={tuple(action.shape)}, mean={action.mean():.4f}, std={action.std():.4f}, "
|
||||
f"min={action.min():.4f}, max={action.max():.4f} | "
|
||||
f"state: shape={tuple(state.shape)}, mean={state.mean():.4f}"
|
||||
)
|
||||
|
||||
batch = preprocessor(batch)
|
||||
|
||||
if is_main_process and (step < 3 or (cfg.log_freq > 0 and step % (cfg.log_freq * 10) == 0)):
|
||||
action = batch.get("action")
|
||||
state = batch.get("observation.state")
|
||||
if action is not None:
|
||||
logging.info(
|
||||
f"[DEBUG step={step}] POST-PROCESSOR — "
|
||||
f"action: shape={tuple(action.shape)}, mean={action.mean():.4f}, std={action.std():.4f}, "
|
||||
f"min={action.min():.4f}, max={action.max():.4f}"
|
||||
)
|
||||
if state is not None:
|
||||
logging.info(
|
||||
f"[DEBUG step={step}] POST-PROCESSOR — "
|
||||
f"state: shape={tuple(state.shape)}, mean={state.mean():.4f}, std={state.std():.4f}"
|
||||
)
|
||||
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
|
||||
@@ -1,124 +1,344 @@
|
||||
"""Tests for delta action transforms using a local dummy dataset."""
|
||||
"""Tests for delta action transforms — full pipeline validation.
|
||||
|
||||
Tests the complete flow matching OpenPI:
|
||||
raw actions → DeltaActions → Normalize(delta_stats) → model → Unnormalize → AbsoluteActions
|
||||
|
||||
Uses real dataset: lerobot-data-collection/dagger_final_1_21
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.compute_stats import get_feature_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor import TransitionKey, batch_to_transition
|
||||
from lerobot.processor.delta_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
to_absolute_actions,
|
||||
to_delta_actions,
|
||||
)
|
||||
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
ACTION_DIM = 14
|
||||
STATE_DIM = 14
|
||||
CHUNK_SIZE = 10
|
||||
REPO_ID = "lerobot-data-collection/dagger_final_1_21"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {
|
||||
"action": {"dtype": "float32", "shape": (ACTION_DIM,), "names": None},
|
||||
"observation.state": {"dtype": "float32", "shape": (STATE_DIM,), "names": None},
|
||||
}
|
||||
ds = empty_lerobot_dataset_factory(root=tmp_path / "delta_test", features=features)
|
||||
for ep in range(2):
|
||||
for _ in range(5):
|
||||
ds.add_frame(
|
||||
{
|
||||
"action": np.random.randn(ACTION_DIM).astype(np.float32),
|
||||
"observation.state": np.random.randn(STATE_DIM).astype(np.float32),
|
||||
"task": f"task_{ep}",
|
||||
}
|
||||
)
|
||||
ds.save_episode()
|
||||
ds.finalize()
|
||||
return ds
|
||||
@pytest.fixture(scope="module")
|
||||
def dataset():
|
||||
return LeRobotDataset(REPO_ID, episodes=[0])
|
||||
|
||||
|
||||
def _collate(dataset, indices):
|
||||
items = [dataset[i] for i in indices]
|
||||
batch = {}
|
||||
for key in items[0]:
|
||||
vals = [item[key] for item in items]
|
||||
if isinstance(vals[0], torch.Tensor):
|
||||
batch[key] = torch.stack(vals)
|
||||
else:
|
||||
batch[key] = vals
|
||||
return batch
|
||||
@pytest.fixture(scope="module")
|
||||
def action_dim(dataset):
|
||||
return dataset.meta.features["action"]["shape"][0]
|
||||
|
||||
|
||||
def test_roundtrip_3d(dataset):
|
||||
"""Delta then absolute on real data should recover original actions."""
|
||||
batch = _collate(dataset, range(4))
|
||||
actions = batch[ACTION].unsqueeze(1).expand(-1, 10, -1).clone()
|
||||
state = batch[OBS_STATE]
|
||||
mask = [True] * actions.shape[-1]
|
||||
def _build_action_chunks(dataset, chunk_size, max_chunks=50):
|
||||
"""Build action chunks from hf_dataset, like the training script does."""
|
||||
hf = dataset.hf_dataset
|
||||
total = len(hf)
|
||||
all_ep = torch.tensor([int(hf[i]["episode_index"]) for i in range(total)])
|
||||
chunks, states = [], []
|
||||
for i in range(total - chunk_size + 1):
|
||||
if all_ep[i] != all_ep[i + chunk_size - 1]:
|
||||
continue
|
||||
chunk_actions = torch.stack([hf[i + k]["action"] for k in range(chunk_size)]).float()
|
||||
state = hf[i]["observation.state"].float()
|
||||
chunks.append(chunk_actions)
|
||||
states.append(state)
|
||||
if len(chunks) >= max_chunks:
|
||||
break
|
||||
assert len(chunks) > 0, f"No valid chunks found. total={total}, ep_indices={all_ep.tolist()}"
|
||||
return torch.stack(chunks), torch.stack(states)
|
||||
|
||||
delta = to_delta_actions(actions, state, mask)
|
||||
recovered = to_absolute_actions(delta, state, mask)
|
||||
|
||||
def _compute_delta_chunk_stats(action_chunks, states, mask):
|
||||
all_deltas = []
|
||||
for actions, state in zip(action_chunks, states):
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
all_deltas.append(delta.numpy())
|
||||
all_delta = np.concatenate(all_deltas, axis=0)
|
||||
return get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
|
||||
|
||||
|
||||
# --- Basic roundtrip tests ---
|
||||
|
||||
def test_roundtrip_3d(action_dim):
|
||||
actions = torch.randn(4, CHUNK_SIZE, action_dim)
|
||||
state = torch.randn(4, action_dim)
|
||||
mask = [True] * action_dim
|
||||
recovered = to_absolute_actions(to_delta_actions(actions, state, mask), state, mask)
|
||||
torch.testing.assert_close(recovered, actions)
|
||||
|
||||
|
||||
def test_roundtrip_2d(dataset):
|
||||
"""Works with (B, action_dim) shaped actions too."""
|
||||
batch = _collate(dataset, range(4))
|
||||
actions = batch[ACTION]
|
||||
state = batch[OBS_STATE]
|
||||
mask = [True] * actions.shape[-1]
|
||||
|
||||
delta = to_delta_actions(actions, state, mask)
|
||||
recovered = to_absolute_actions(delta, state, mask)
|
||||
def test_roundtrip_2d(action_dim):
|
||||
actions = torch.randn(4, action_dim)
|
||||
state = torch.randn(4, action_dim)
|
||||
mask = [True] * action_dim
|
||||
recovered = to_absolute_actions(to_delta_actions(actions, state, mask), state, mask)
|
||||
torch.testing.assert_close(recovered, actions)
|
||||
|
||||
|
||||
def test_delta_changes_all_dims(dataset):
|
||||
"""All dims should change when mask is all True."""
|
||||
batch = _collate(dataset, range(4))
|
||||
actions = batch[ACTION].unsqueeze(1)
|
||||
state = batch[OBS_STATE]
|
||||
mask = [True] * actions.shape[-1]
|
||||
|
||||
delta = to_delta_actions(actions, state, mask)
|
||||
assert (delta - actions).abs().sum() > 0
|
||||
|
||||
|
||||
def test_no_mutation(dataset):
|
||||
"""Original tensors should not be modified."""
|
||||
batch = _collate(dataset, range(2))
|
||||
actions = batch[ACTION].unsqueeze(1)
|
||||
def test_no_mutation(action_dim):
|
||||
actions = torch.randn(2, CHUNK_SIZE, action_dim)
|
||||
original = actions.clone()
|
||||
state = batch[OBS_STATE]
|
||||
mask = [True] * actions.shape[-1]
|
||||
|
||||
to_delta_actions(actions, state, mask)
|
||||
state = torch.randn(2, action_dim)
|
||||
to_delta_actions(actions, state, [True] * action_dim)
|
||||
torch.testing.assert_close(actions, original)
|
||||
|
||||
|
||||
def test_processor_step_roundtrip(dataset):
|
||||
def test_exclude_joints_supports_partial_name_matching():
|
||||
names = [
|
||||
"right_joint_1.pos",
|
||||
"right_gripper.pos",
|
||||
"left_joint_1.pos",
|
||||
"left_gripper.pos",
|
||||
]
|
||||
step = DeltaActionsProcessorStep(enabled=True, exclude_joints=["gripper"], action_names=names)
|
||||
assert step._build_mask(len(names)) == [True, False, True, False]
|
||||
|
||||
|
||||
# --- Chunk-level delta stats test ---
|
||||
|
||||
def test_chunk_stats_have_larger_std_than_frame_stats(dataset, action_dim):
|
||||
"""Chunk-level delta stats should have larger std than per-frame delta stats."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
chunk_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
|
||||
# Per-frame stats
|
||||
hf = dataset.hf_dataset
|
||||
n = min(500, len(hf))
|
||||
frame_actions = torch.stack([hf[i]["action"] for i in range(n)]).float()
|
||||
frame_states = torch.stack([hf[i]["observation.state"] for i in range(n)]).float()
|
||||
frame_deltas = to_delta_actions(frame_actions, frame_states, mask).numpy()
|
||||
frame_stats = get_feature_stats(frame_deltas, axis=0, keepdims=frame_deltas.ndim == 1)
|
||||
|
||||
assert chunk_stats["std"].mean() >= frame_stats["std"].mean(), (
|
||||
f"Chunk std ({chunk_stats['std'].mean():.4f}) should be >= "
|
||||
f"frame std ({frame_stats['std'].mean():.4f})"
|
||||
)
|
||||
|
||||
|
||||
# --- Full pipeline roundtrip: delta → normalize → unnormalize → absolute ---
|
||||
|
||||
def test_full_pipeline_roundtrip(dataset, action_dim):
|
||||
"""Test the complete OpenPI pipeline: delta → normalize → unnormalize → absolute."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
stats = {ACTION: {k: v for k, v in delta_stats.items()}}
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
|
||||
delta_step = DeltaActionsProcessorStep(enabled=True)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
absolute_step = AbsoluteActionsProcessorStep(enabled=True, delta_step=delta_step)
|
||||
|
||||
original_actions = action_chunks[0].unsqueeze(0)
|
||||
state = states[0].unsqueeze(0)
|
||||
|
||||
batch = {ACTION: original_actions, OBS_STATE: state}
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Forward: delta → normalize
|
||||
t1 = delta_step(transition)
|
||||
t2 = normalizer(t1)
|
||||
|
||||
normalized_action = t2[TransitionKey.ACTION]
|
||||
assert normalized_action.abs().mean() < 10, (
|
||||
f"Normalized actions should be in reasonable range, got mean abs {normalized_action.abs().mean():.2f}"
|
||||
)
|
||||
|
||||
# Reverse: unnormalize → absolute
|
||||
t3 = unnormalizer(t2)
|
||||
t4 = absolute_step(t3)
|
||||
|
||||
recovered_actions = t4[TransitionKey.ACTION]
|
||||
torch.testing.assert_close(recovered_actions, original_actions, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
def test_normalized_delta_values_are_reasonable(dataset, action_dim):
|
||||
"""With correct chunk stats, normalized delta actions should be in a reasonable range."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
mean = torch.tensor(delta_stats["mean"]).float()
|
||||
std = torch.tensor(delta_stats["std"]).float()
|
||||
|
||||
all_normalized = []
|
||||
for actions, state in zip(action_chunks, states):
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
normalized = (delta - mean) / (std + 1e-6)
|
||||
all_normalized.append(normalized)
|
||||
|
||||
all_normalized = torch.cat(all_normalized, dim=0)
|
||||
|
||||
pct_in_range = (all_normalized.abs() < 5).float().mean()
|
||||
assert pct_in_range > 0.9, (
|
||||
f"Only {pct_in_range*100:.1f}% of normalized values in [-5, 5], expected >90%"
|
||||
)
|
||||
|
||||
assert all_normalized.mean().abs() < 1.0, (
|
||||
f"Mean of normalized deltas is {all_normalized.mean():.2f}, expected near 0"
|
||||
)
|
||||
|
||||
|
||||
def test_processor_step_roundtrip(dataset, action_dim):
|
||||
"""DeltaActionsProcessorStep applies delta; to_absolute_actions recovers original."""
|
||||
batch = _collate(dataset, range(4))
|
||||
hf = dataset.hf_dataset
|
||||
batch = {
|
||||
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
|
||||
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
|
||||
}
|
||||
original_actions = batch[ACTION].clone()
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
step = DeltaActionsProcessorStep(enabled=True)
|
||||
delta_transition = step(transition)
|
||||
|
||||
delta_actions = delta_transition[TransitionKey.ACTION]
|
||||
assert not torch.allclose(delta_actions, original_actions)
|
||||
assert not torch.allclose(delta_transition[TransitionKey.ACTION], original_actions)
|
||||
|
||||
state = transition[TransitionKey.OBSERVATION][OBS_STATE]
|
||||
mask = [True] * original_actions.shape[-1]
|
||||
recovered = to_absolute_actions(delta_actions, state, mask)
|
||||
mask = [True] * action_dim
|
||||
recovered = to_absolute_actions(delta_transition[TransitionKey.ACTION], state, mask)
|
||||
torch.testing.assert_close(recovered, original_actions)
|
||||
|
||||
|
||||
def test_processor_step_disabled_is_noop(dataset):
|
||||
def test_processor_step_disabled_is_noop(dataset, action_dim):
|
||||
"""enabled=False should be a no-op."""
|
||||
batch = _collate(dataset, range(2))
|
||||
hf = dataset.hf_dataset
|
||||
batch = {
|
||||
ACTION: torch.stack([hf[i]["action"] for i in range(2)]),
|
||||
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(2)]),
|
||||
}
|
||||
original = batch[ACTION].clone()
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
result = DeltaActionsProcessorStep(enabled=False)(transition)
|
||||
torch.testing.assert_close(result[TransitionKey.ACTION], original)
|
||||
|
||||
|
||||
# --- Training batch shape validation ---
|
||||
|
||||
def test_delta_with_action_chunks(dataset, action_dim):
|
||||
"""Verify delta works correctly with (B, chunk_size, action_dim) shaped actions."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
|
||||
# Simulate a training batch: actions=(B, chunk_size, action_dim), state=(B, state_dim)
|
||||
batch_actions = action_chunks[:4] # (4, chunk_size, action_dim)
|
||||
batch_states = states[:4] # (4, state_dim)
|
||||
|
||||
mask = [True] * action_dim
|
||||
delta = to_delta_actions(batch_actions, batch_states, mask)
|
||||
|
||||
# First action in each chunk should be close to zero (action[t] - state[t] ≈ small)
|
||||
first_deltas = delta[:, 0, :] # (B, action_dim)
|
||||
assert first_deltas.abs().mean() < delta.abs().mean(), (
|
||||
f"First action in chunk should have smaller delta than average. "
|
||||
f"First: {first_deltas.abs().mean():.4f}, Average: {delta.abs().mean():.4f}"
|
||||
)
|
||||
|
||||
# Later actions should have larger deltas
|
||||
last_deltas = delta[:, -1, :] # (B, action_dim)
|
||||
assert last_deltas.abs().mean() >= first_deltas.abs().mean(), (
|
||||
f"Last action in chunk should have >= delta than first. "
|
||||
f"Last: {last_deltas.abs().mean():.4f}, First: {first_deltas.abs().mean():.4f}"
|
||||
)
|
||||
|
||||
# Roundtrip
|
||||
recovered = to_absolute_actions(delta, batch_states, mask)
|
||||
torch.testing.assert_close(recovered, batch_actions)
|
||||
|
||||
|
||||
def test_delta_stats_match_actual_data_distribution(dataset, action_dim):
|
||||
"""Verify computed stats match the actual delta distribution."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
# Compute stats like the training script does
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
|
||||
# Also compute directly
|
||||
all_deltas = []
|
||||
for actions, state in zip(action_chunks, states):
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
all_deltas.append(delta)
|
||||
all_deltas_tensor = torch.cat(all_deltas, dim=0)
|
||||
|
||||
# Compare mean
|
||||
actual_mean = all_deltas_tensor.mean(dim=0).numpy()
|
||||
np.testing.assert_allclose(delta_stats["mean"], actual_mean, atol=0.01)
|
||||
|
||||
# Compare std
|
||||
actual_std = all_deltas_tensor.std(dim=0).numpy()
|
||||
np.testing.assert_allclose(delta_stats["std"], actual_std, atol=0.1)
|
||||
|
||||
# Verify q01 < mean < q99
|
||||
assert (delta_stats["q01"] < delta_stats["mean"]).all(), "q01 should be < mean"
|
||||
assert (delta_stats["mean"] < delta_stats["q99"]).all(), "mean should be < q99"
|
||||
|
||||
|
||||
def test_quantile_normalization_roundtrip(dataset, action_dim):
|
||||
"""Full roundtrip with QUANTILES normalization (what OpenPI uses for pi05)."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
stats = {ACTION: {k: v for k, v in delta_stats.items()}}
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.QUANTILES}
|
||||
|
||||
delta_step = DeltaActionsProcessorStep(enabled=True)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
absolute_step = AbsoluteActionsProcessorStep(enabled=True, delta_step=delta_step)
|
||||
|
||||
original_actions = action_chunks[0].unsqueeze(0)
|
||||
state = states[0].unsqueeze(0)
|
||||
|
||||
batch = {ACTION: original_actions, OBS_STATE: state}
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Forward: delta → quantile normalize
|
||||
t1 = delta_step(transition)
|
||||
t2 = normalizer(t1)
|
||||
|
||||
normalized = t2[TransitionKey.ACTION]
|
||||
# Most values should be in [-1, 1] with quantile normalization
|
||||
pct_in_range = (normalized.abs() < 2).float().mean()
|
||||
assert pct_in_range > 0.5, (
|
||||
f"Only {pct_in_range*100:.1f}% in [-2, 2] after quantile norm, expected >50%"
|
||||
)
|
||||
|
||||
# Reverse: unnormalize → absolute
|
||||
t3 = unnormalizer(t2)
|
||||
t4 = absolute_step(t3)
|
||||
|
||||
recovered = t4[TransitionKey.ACTION]
|
||||
torch.testing.assert_close(recovered, original_actions, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
def test_state_not_modified_by_delta(dataset, action_dim):
|
||||
"""State should never be modified by the delta processor."""
|
||||
hf = dataset.hf_dataset
|
||||
batch = {
|
||||
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
|
||||
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
|
||||
}
|
||||
original_state = batch[OBS_STATE].clone()
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
step = DeltaActionsProcessorStep(enabled=True)
|
||||
result = step(transition)
|
||||
|
||||
result_state = result[TransitionKey.OBSERVATION][OBS_STATE]
|
||||
torch.testing.assert_close(result_state, original_state)
|
||||
|
||||
Reference in New Issue
Block a user