mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
Merge branch 'feat/add_relative_action_pi_models' into feat/mirror
This commit is contained in:
@@ -210,8 +210,21 @@ class DeltaActionsProcessorStep(ProcessorStep):
|
|||||||
def _build_mask(self, action_dim: int) -> list[bool]:
|
def _build_mask(self, action_dim: int) -> list[bool]:
|
||||||
if not self.exclude_joints or self.action_names is None:
|
if not self.exclude_joints or self.action_names is None:
|
||||||
return [True] * action_dim
|
return [True] * action_dim
|
||||||
exclude = set(self.exclude_joints)
|
|
||||||
return [n not in exclude for n in self.action_names]
|
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
|
||||||
|
if not exclude_tokens:
|
||||||
|
return [True] * action_dim
|
||||||
|
|
||||||
|
mask = []
|
||||||
|
for name in self.action_names[:action_dim]:
|
||||||
|
action_name = str(name).lower()
|
||||||
|
is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
|
||||||
|
mask.append(not is_excluded)
|
||||||
|
|
||||||
|
if len(mask) < action_dim:
|
||||||
|
mask.extend([True] * (action_dim - len(mask)))
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
observation = transition.get(TransitionKey.OBSERVATION, {})
|
observation = transition.get(TransitionKey.OBSERVATION, {})
|
||||||
|
|||||||
@@ -209,6 +209,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||||
|
delta_action_stats = None
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
logging.info("Creating dataset")
|
logging.info("Creating dataset")
|
||||||
dataset = make_dataset(cfg)
|
dataset = make_dataset(cfg)
|
||||||
@@ -218,13 +219,27 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lerobot.datasets.compute_stats import get_feature_stats
|
from lerobot.datasets.compute_stats import get_feature_stats
|
||||||
from lerobot.processor.delta_action_processor import to_delta_actions
|
from lerobot.processor.delta_action_processor import DeltaActionsProcessorStep, to_delta_actions
|
||||||
|
|
||||||
chunk_size = cfg.policy.chunk_size
|
chunk_size = cfg.policy.chunk_size
|
||||||
hf = dataset.hf_dataset
|
hf = dataset.hf_dataset
|
||||||
total_frames = len(hf)
|
total_frames = len(hf)
|
||||||
max_samples = min(100_000, total_frames - chunk_size)
|
sample_upper_bound = total_frames - chunk_size
|
||||||
indices = np.random.choice(total_frames - chunk_size, max_samples, replace=False)
|
if sample_upper_bound <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot compute delta action stats: total_frames={total_frames}, chunk_size={chunk_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
max_samples = min(100_000, sample_upper_bound)
|
||||||
|
indices = np.random.choice(sample_upper_bound, max_samples, replace=False)
|
||||||
|
|
||||||
|
action_names = dataset.meta.features.get("action", {}).get("names")
|
||||||
|
delta_mask_step = DeltaActionsProcessorStep(
|
||||||
|
enabled=True,
|
||||||
|
exclude_joints=getattr(cfg.policy, "delta_exclude_joints", []),
|
||||||
|
action_names=action_names,
|
||||||
|
)
|
||||||
|
delta_mask = delta_mask_step._build_mask(dataset.meta.features["action"]["shape"][0])
|
||||||
logging.info(
|
logging.info(
|
||||||
f"use_delta_actions is enabled — computing delta action stats "
|
f"use_delta_actions is enabled — computing delta action stats "
|
||||||
f"from {max_samples} chunk samples (chunk_size={chunk_size})"
|
f"from {max_samples} chunk samples (chunk_size={chunk_size})"
|
||||||
@@ -243,13 +258,16 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
actions = torch.tensor(np.stack([np.asarray(a) for a in chunk_data["action"]])).float()
|
actions = torch.tensor(np.stack([np.asarray(a) for a in chunk_data["action"]])).float()
|
||||||
state = torch.tensor(np.asarray(chunk_data["observation.state"][0])).float()
|
state = torch.tensor(np.asarray(chunk_data["observation.state"][0])).float()
|
||||||
|
|
||||||
mask = [True] * actions.shape[-1]
|
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), delta_mask).squeeze(0)
|
||||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
|
||||||
all_delta_actions.append(delta.numpy())
|
all_delta_actions.append(delta.numpy())
|
||||||
|
|
||||||
|
if not all_delta_actions:
|
||||||
|
raise RuntimeError("Failed to compute delta action stats: no valid chunks found.")
|
||||||
|
|
||||||
all_delta = np.concatenate(all_delta_actions, axis=0)
|
all_delta = np.concatenate(all_delta_actions, axis=0)
|
||||||
delta_stats = get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
|
delta_stats = get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
|
||||||
dataset.meta.stats["action"] = delta_stats
|
delta_action_stats = delta_stats
|
||||||
|
dataset.meta.stats["action"] = delta_action_stats
|
||||||
|
|
||||||
norm_type = "UNKNOWN"
|
norm_type = "UNKNOWN"
|
||||||
if hasattr(cfg.policy, "normalization_mapping"):
|
if hasattr(cfg.policy, "normalization_mapping"):
|
||||||
@@ -257,8 +275,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
action_norm = cfg.policy.normalization_mapping.get("ACTION", None)
|
action_norm = cfg.policy.normalization_mapping.get("ACTION", None)
|
||||||
norm_type = action_norm.value if action_norm else "UNKNOWN"
|
norm_type = action_norm.value if action_norm else "UNKNOWN"
|
||||||
|
|
||||||
|
excluded_dims = len(delta_mask) - sum(delta_mask)
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Delta action stats ({len(all_delta_actions)} chunks, {len(all_delta)} values, norm={norm_type}): "
|
f"Delta action stats ({len(all_delta_actions)} chunks, {len(all_delta)} values, norm={norm_type}): "
|
||||||
|
f"delta_dims={sum(delta_mask)}/{len(delta_mask)} (excluded={excluded_dims}), "
|
||||||
f"mean={np.abs(delta_stats['mean']).mean():.4f}, std={delta_stats['std'].mean():.4f}, "
|
f"mean={np.abs(delta_stats['mean']).mean():.4f}, std={delta_stats['std'].mean():.4f}, "
|
||||||
f"q01={delta_stats['q01'].mean():.4f}, q99={delta_stats['q99'].mean():.4f}"
|
f"q01={delta_stats['q01'].mean():.4f}, q99={delta_stats['q99'].mean():.4f}"
|
||||||
)
|
)
|
||||||
@@ -272,6 +292,15 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
if not is_main_process:
|
if not is_main_process:
|
||||||
dataset = make_dataset(cfg)
|
dataset = make_dataset(cfg)
|
||||||
|
|
||||||
|
# Ensure all ranks use the exact same delta action stats.
|
||||||
|
if getattr(cfg.policy, "use_delta_actions", False):
|
||||||
|
if accelerator.num_processes > 1 and torch.distributed.is_initialized():
|
||||||
|
stats_list = [delta_action_stats]
|
||||||
|
torch.distributed.broadcast_object_list(stats_list, src=0)
|
||||||
|
delta_action_stats = stats_list[0]
|
||||||
|
if delta_action_stats is not None:
|
||||||
|
dataset.meta.stats["action"] = delta_action_stats
|
||||||
|
|
||||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||||
@@ -297,10 +326,22 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
# Wait for all processes to finish policy creation before continuing
|
# Wait for all processes to finish policy creation before continuing
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
processor_pretrained_path = cfg.policy.pretrained_path
|
||||||
|
if (
|
||||||
|
getattr(cfg.policy, "use_delta_actions", False)
|
||||||
|
and processor_pretrained_path is not None
|
||||||
|
and not cfg.resume
|
||||||
|
):
|
||||||
|
logging.warning(
|
||||||
|
"use_delta_actions=true with pretrained processors can skip delta transforms if "
|
||||||
|
"the checkpoint processors do not define them. Building processors from current policy config."
|
||||||
|
)
|
||||||
|
processor_pretrained_path = None
|
||||||
|
|
||||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||||
processor_kwargs = {}
|
processor_kwargs = {}
|
||||||
postprocessor_kwargs = {}
|
postprocessor_kwargs = {}
|
||||||
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
|
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||||
# Only provide dataset_stats when not resuming from saved processor state
|
# Only provide dataset_stats when not resuming from saved processor state
|
||||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||||
|
|
||||||
@@ -308,7 +349,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
if cfg.policy.type == "sarm":
|
if cfg.policy.type == "sarm":
|
||||||
processor_kwargs["dataset_meta"] = dataset.meta
|
processor_kwargs["dataset_meta"] = dataset.meta
|
||||||
|
|
||||||
if cfg.policy.pretrained_path is not None:
|
if processor_pretrained_path is not None:
|
||||||
processor_kwargs["preprocessor_overrides"] = {
|
processor_kwargs["preprocessor_overrides"] = {
|
||||||
"device_processor": {"device": device.type},
|
"device_processor": {"device": device.type},
|
||||||
"normalizer_processor": {
|
"normalizer_processor": {
|
||||||
@@ -330,7 +371,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
|
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
policy_cfg=cfg.policy,
|
policy_cfg=cfg.policy,
|
||||||
pretrained_path=cfg.policy.pretrained_path,
|
pretrained_path=processor_pretrained_path,
|
||||||
**processor_kwargs,
|
**processor_kwargs,
|
||||||
**postprocessor_kwargs,
|
**postprocessor_kwargs,
|
||||||
)
|
)
|
||||||
@@ -448,7 +489,36 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
for _ in range(step, cfg.steps):
|
for _ in range(step, cfg.steps):
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
batch = next(dl_iter)
|
batch = next(dl_iter)
|
||||||
|
|
||||||
|
# Debug logging for first few steps and periodically
|
||||||
|
if is_main_process and (step < 3 or (cfg.log_freq > 0 and step % (cfg.log_freq * 10) == 0)):
|
||||||
|
action = batch.get("action")
|
||||||
|
state = batch.get("observation.state")
|
||||||
|
if action is not None and state is not None:
|
||||||
|
logging.info(
|
||||||
|
f"[DEBUG step={step}] PRE-PROCESSOR — "
|
||||||
|
f"action: shape={tuple(action.shape)}, mean={action.mean():.4f}, std={action.std():.4f}, "
|
||||||
|
f"min={action.min():.4f}, max={action.max():.4f} | "
|
||||||
|
f"state: shape={tuple(state.shape)}, mean={state.mean():.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
batch = preprocessor(batch)
|
batch = preprocessor(batch)
|
||||||
|
|
||||||
|
if is_main_process and (step < 3 or (cfg.log_freq > 0 and step % (cfg.log_freq * 10) == 0)):
|
||||||
|
action = batch.get("action")
|
||||||
|
state = batch.get("observation.state")
|
||||||
|
if action is not None:
|
||||||
|
logging.info(
|
||||||
|
f"[DEBUG step={step}] POST-PROCESSOR — "
|
||||||
|
f"action: shape={tuple(action.shape)}, mean={action.mean():.4f}, std={action.std():.4f}, "
|
||||||
|
f"min={action.min():.4f}, max={action.max():.4f}"
|
||||||
|
)
|
||||||
|
if state is not None:
|
||||||
|
logging.info(
|
||||||
|
f"[DEBUG step={step}] POST-PROCESSOR — "
|
||||||
|
f"state: shape={tuple(state.shape)}, mean={state.mean():.4f}, std={state.std():.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||||
|
|
||||||
train_tracker, output_dict = update_policy(
|
train_tracker, output_dict = update_policy(
|
||||||
|
|||||||
@@ -1,124 +1,344 @@
|
|||||||
"""Tests for delta action transforms using a local dummy dataset."""
|
"""Tests for delta action transforms — full pipeline validation.
|
||||||
|
|
||||||
|
Tests the complete flow matching OpenPI:
|
||||||
|
raw actions → DeltaActions → Normalize(delta_stats) → model → Unnormalize → AbsoluteActions
|
||||||
|
|
||||||
|
Uses real dataset: lerobot-data-collection/dagger_final_1_21
|
||||||
|
"""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
from lerobot.datasets.compute_stats import get_feature_stats
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.processor import TransitionKey, batch_to_transition
|
from lerobot.processor import TransitionKey, batch_to_transition
|
||||||
from lerobot.processor.delta_action_processor import (
|
from lerobot.processor.delta_action_processor import (
|
||||||
|
AbsoluteActionsProcessorStep,
|
||||||
DeltaActionsProcessorStep,
|
DeltaActionsProcessorStep,
|
||||||
to_absolute_actions,
|
to_absolute_actions,
|
||||||
to_delta_actions,
|
to_delta_actions,
|
||||||
)
|
)
|
||||||
|
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
||||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||||
|
|
||||||
ACTION_DIM = 14
|
CHUNK_SIZE = 10
|
||||||
STATE_DIM = 14
|
REPO_ID = "lerobot-data-collection/dagger_final_1_21"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(scope="module")
|
||||||
def dataset(tmp_path, empty_lerobot_dataset_factory):
|
def dataset():
|
||||||
features = {
|
return LeRobotDataset(REPO_ID, episodes=[0])
|
||||||
"action": {"dtype": "float32", "shape": (ACTION_DIM,), "names": None},
|
|
||||||
"observation.state": {"dtype": "float32", "shape": (STATE_DIM,), "names": None},
|
|
||||||
}
|
|
||||||
ds = empty_lerobot_dataset_factory(root=tmp_path / "delta_test", features=features)
|
|
||||||
for ep in range(2):
|
|
||||||
for _ in range(5):
|
|
||||||
ds.add_frame(
|
|
||||||
{
|
|
||||||
"action": np.random.randn(ACTION_DIM).astype(np.float32),
|
|
||||||
"observation.state": np.random.randn(STATE_DIM).astype(np.float32),
|
|
||||||
"task": f"task_{ep}",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
ds.save_episode()
|
|
||||||
ds.finalize()
|
|
||||||
return ds
|
|
||||||
|
|
||||||
|
|
||||||
def _collate(dataset, indices):
|
@pytest.fixture(scope="module")
|
||||||
items = [dataset[i] for i in indices]
|
def action_dim(dataset):
|
||||||
batch = {}
|
return dataset.meta.features["action"]["shape"][0]
|
||||||
for key in items[0]:
|
|
||||||
vals = [item[key] for item in items]
|
|
||||||
if isinstance(vals[0], torch.Tensor):
|
|
||||||
batch[key] = torch.stack(vals)
|
|
||||||
else:
|
|
||||||
batch[key] = vals
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
def test_roundtrip_3d(dataset):
|
def _build_action_chunks(dataset, chunk_size, max_chunks=50):
|
||||||
"""Delta then absolute on real data should recover original actions."""
|
"""Build action chunks from hf_dataset, like the training script does."""
|
||||||
batch = _collate(dataset, range(4))
|
hf = dataset.hf_dataset
|
||||||
actions = batch[ACTION].unsqueeze(1).expand(-1, 10, -1).clone()
|
total = len(hf)
|
||||||
state = batch[OBS_STATE]
|
all_ep = torch.tensor([int(hf[i]["episode_index"]) for i in range(total)])
|
||||||
mask = [True] * actions.shape[-1]
|
chunks, states = [], []
|
||||||
|
for i in range(total - chunk_size + 1):
|
||||||
|
if all_ep[i] != all_ep[i + chunk_size - 1]:
|
||||||
|
continue
|
||||||
|
chunk_actions = torch.stack([hf[i + k]["action"] for k in range(chunk_size)]).float()
|
||||||
|
state = hf[i]["observation.state"].float()
|
||||||
|
chunks.append(chunk_actions)
|
||||||
|
states.append(state)
|
||||||
|
if len(chunks) >= max_chunks:
|
||||||
|
break
|
||||||
|
assert len(chunks) > 0, f"No valid chunks found. total={total}, ep_indices={all_ep.tolist()}"
|
||||||
|
return torch.stack(chunks), torch.stack(states)
|
||||||
|
|
||||||
delta = to_delta_actions(actions, state, mask)
|
|
||||||
recovered = to_absolute_actions(delta, state, mask)
|
def _compute_delta_chunk_stats(action_chunks, states, mask):
|
||||||
|
all_deltas = []
|
||||||
|
for actions, state in zip(action_chunks, states):
|
||||||
|
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||||
|
all_deltas.append(delta.numpy())
|
||||||
|
all_delta = np.concatenate(all_deltas, axis=0)
|
||||||
|
return get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Basic roundtrip tests ---
|
||||||
|
|
||||||
|
def test_roundtrip_3d(action_dim):
|
||||||
|
actions = torch.randn(4, CHUNK_SIZE, action_dim)
|
||||||
|
state = torch.randn(4, action_dim)
|
||||||
|
mask = [True] * action_dim
|
||||||
|
recovered = to_absolute_actions(to_delta_actions(actions, state, mask), state, mask)
|
||||||
torch.testing.assert_close(recovered, actions)
|
torch.testing.assert_close(recovered, actions)
|
||||||
|
|
||||||
|
|
||||||
def test_roundtrip_2d(dataset):
|
def test_roundtrip_2d(action_dim):
|
||||||
"""Works with (B, action_dim) shaped actions too."""
|
actions = torch.randn(4, action_dim)
|
||||||
batch = _collate(dataset, range(4))
|
state = torch.randn(4, action_dim)
|
||||||
actions = batch[ACTION]
|
mask = [True] * action_dim
|
||||||
state = batch[OBS_STATE]
|
recovered = to_absolute_actions(to_delta_actions(actions, state, mask), state, mask)
|
||||||
mask = [True] * actions.shape[-1]
|
|
||||||
|
|
||||||
delta = to_delta_actions(actions, state, mask)
|
|
||||||
recovered = to_absolute_actions(delta, state, mask)
|
|
||||||
torch.testing.assert_close(recovered, actions)
|
torch.testing.assert_close(recovered, actions)
|
||||||
|
|
||||||
|
|
||||||
def test_delta_changes_all_dims(dataset):
|
def test_no_mutation(action_dim):
|
||||||
"""All dims should change when mask is all True."""
|
actions = torch.randn(2, CHUNK_SIZE, action_dim)
|
||||||
batch = _collate(dataset, range(4))
|
|
||||||
actions = batch[ACTION].unsqueeze(1)
|
|
||||||
state = batch[OBS_STATE]
|
|
||||||
mask = [True] * actions.shape[-1]
|
|
||||||
|
|
||||||
delta = to_delta_actions(actions, state, mask)
|
|
||||||
assert (delta - actions).abs().sum() > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_no_mutation(dataset):
|
|
||||||
"""Original tensors should not be modified."""
|
|
||||||
batch = _collate(dataset, range(2))
|
|
||||||
actions = batch[ACTION].unsqueeze(1)
|
|
||||||
original = actions.clone()
|
original = actions.clone()
|
||||||
state = batch[OBS_STATE]
|
state = torch.randn(2, action_dim)
|
||||||
mask = [True] * actions.shape[-1]
|
to_delta_actions(actions, state, [True] * action_dim)
|
||||||
|
|
||||||
to_delta_actions(actions, state, mask)
|
|
||||||
torch.testing.assert_close(actions, original)
|
torch.testing.assert_close(actions, original)
|
||||||
|
|
||||||
|
|
||||||
def test_processor_step_roundtrip(dataset):
|
def test_exclude_joints_supports_partial_name_matching():
|
||||||
|
names = [
|
||||||
|
"right_joint_1.pos",
|
||||||
|
"right_gripper.pos",
|
||||||
|
"left_joint_1.pos",
|
||||||
|
"left_gripper.pos",
|
||||||
|
]
|
||||||
|
step = DeltaActionsProcessorStep(enabled=True, exclude_joints=["gripper"], action_names=names)
|
||||||
|
assert step._build_mask(len(names)) == [True, False, True, False]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Chunk-level delta stats test ---
|
||||||
|
|
||||||
|
def test_chunk_stats_have_larger_std_than_frame_stats(dataset, action_dim):
|
||||||
|
"""Chunk-level delta stats should have larger std than per-frame delta stats."""
|
||||||
|
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||||
|
mask = [True] * action_dim
|
||||||
|
|
||||||
|
chunk_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||||
|
|
||||||
|
# Per-frame stats
|
||||||
|
hf = dataset.hf_dataset
|
||||||
|
n = min(500, len(hf))
|
||||||
|
frame_actions = torch.stack([hf[i]["action"] for i in range(n)]).float()
|
||||||
|
frame_states = torch.stack([hf[i]["observation.state"] for i in range(n)]).float()
|
||||||
|
frame_deltas = to_delta_actions(frame_actions, frame_states, mask).numpy()
|
||||||
|
frame_stats = get_feature_stats(frame_deltas, axis=0, keepdims=frame_deltas.ndim == 1)
|
||||||
|
|
||||||
|
assert chunk_stats["std"].mean() >= frame_stats["std"].mean(), (
|
||||||
|
f"Chunk std ({chunk_stats['std'].mean():.4f}) should be >= "
|
||||||
|
f"frame std ({frame_stats['std'].mean():.4f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Full pipeline roundtrip: delta → normalize → unnormalize → absolute ---
|
||||||
|
|
||||||
|
def test_full_pipeline_roundtrip(dataset, action_dim):
|
||||||
|
"""Test the complete OpenPI pipeline: delta → normalize → unnormalize → absolute."""
|
||||||
|
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||||
|
mask = [True] * action_dim
|
||||||
|
|
||||||
|
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||||
|
stats = {ACTION: {k: v for k, v in delta_stats.items()}}
|
||||||
|
|
||||||
|
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
|
||||||
|
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||||
|
|
||||||
|
delta_step = DeltaActionsProcessorStep(enabled=True)
|
||||||
|
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
absolute_step = AbsoluteActionsProcessorStep(enabled=True, delta_step=delta_step)
|
||||||
|
|
||||||
|
original_actions = action_chunks[0].unsqueeze(0)
|
||||||
|
state = states[0].unsqueeze(0)
|
||||||
|
|
||||||
|
batch = {ACTION: original_actions, OBS_STATE: state}
|
||||||
|
transition = batch_to_transition(batch)
|
||||||
|
|
||||||
|
# Forward: delta → normalize
|
||||||
|
t1 = delta_step(transition)
|
||||||
|
t2 = normalizer(t1)
|
||||||
|
|
||||||
|
normalized_action = t2[TransitionKey.ACTION]
|
||||||
|
assert normalized_action.abs().mean() < 10, (
|
||||||
|
f"Normalized actions should be in reasonable range, got mean abs {normalized_action.abs().mean():.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reverse: unnormalize → absolute
|
||||||
|
t3 = unnormalizer(t2)
|
||||||
|
t4 = absolute_step(t3)
|
||||||
|
|
||||||
|
recovered_actions = t4[TransitionKey.ACTION]
|
||||||
|
torch.testing.assert_close(recovered_actions, original_actions, atol=1e-4, rtol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalized_delta_values_are_reasonable(dataset, action_dim):
|
||||||
|
"""With correct chunk stats, normalized delta actions should be in a reasonable range."""
|
||||||
|
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||||
|
mask = [True] * action_dim
|
||||||
|
|
||||||
|
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||||
|
mean = torch.tensor(delta_stats["mean"]).float()
|
||||||
|
std = torch.tensor(delta_stats["std"]).float()
|
||||||
|
|
||||||
|
all_normalized = []
|
||||||
|
for actions, state in zip(action_chunks, states):
|
||||||
|
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||||
|
normalized = (delta - mean) / (std + 1e-6)
|
||||||
|
all_normalized.append(normalized)
|
||||||
|
|
||||||
|
all_normalized = torch.cat(all_normalized, dim=0)
|
||||||
|
|
||||||
|
pct_in_range = (all_normalized.abs() < 5).float().mean()
|
||||||
|
assert pct_in_range > 0.9, (
|
||||||
|
f"Only {pct_in_range*100:.1f}% of normalized values in [-5, 5], expected >90%"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all_normalized.mean().abs() < 1.0, (
|
||||||
|
f"Mean of normalized deltas is {all_normalized.mean():.2f}, expected near 0"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_processor_step_roundtrip(dataset, action_dim):
|
||||||
"""DeltaActionsProcessorStep applies delta; to_absolute_actions recovers original."""
|
"""DeltaActionsProcessorStep applies delta; to_absolute_actions recovers original."""
|
||||||
batch = _collate(dataset, range(4))
|
hf = dataset.hf_dataset
|
||||||
|
batch = {
|
||||||
|
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
|
||||||
|
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
|
||||||
|
}
|
||||||
original_actions = batch[ACTION].clone()
|
original_actions = batch[ACTION].clone()
|
||||||
transition = batch_to_transition(batch)
|
transition = batch_to_transition(batch)
|
||||||
|
|
||||||
step = DeltaActionsProcessorStep(enabled=True)
|
step = DeltaActionsProcessorStep(enabled=True)
|
||||||
delta_transition = step(transition)
|
delta_transition = step(transition)
|
||||||
|
assert not torch.allclose(delta_transition[TransitionKey.ACTION], original_actions)
|
||||||
delta_actions = delta_transition[TransitionKey.ACTION]
|
|
||||||
assert not torch.allclose(delta_actions, original_actions)
|
|
||||||
|
|
||||||
state = transition[TransitionKey.OBSERVATION][OBS_STATE]
|
state = transition[TransitionKey.OBSERVATION][OBS_STATE]
|
||||||
mask = [True] * original_actions.shape[-1]
|
mask = [True] * action_dim
|
||||||
recovered = to_absolute_actions(delta_actions, state, mask)
|
recovered = to_absolute_actions(delta_transition[TransitionKey.ACTION], state, mask)
|
||||||
torch.testing.assert_close(recovered, original_actions)
|
torch.testing.assert_close(recovered, original_actions)
|
||||||
|
|
||||||
|
|
||||||
def test_processor_step_disabled_is_noop(dataset):
|
def test_processor_step_disabled_is_noop(dataset, action_dim):
|
||||||
"""enabled=False should be a no-op."""
|
"""enabled=False should be a no-op."""
|
||||||
batch = _collate(dataset, range(2))
|
hf = dataset.hf_dataset
|
||||||
|
batch = {
|
||||||
|
ACTION: torch.stack([hf[i]["action"] for i in range(2)]),
|
||||||
|
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(2)]),
|
||||||
|
}
|
||||||
original = batch[ACTION].clone()
|
original = batch[ACTION].clone()
|
||||||
transition = batch_to_transition(batch)
|
transition = batch_to_transition(batch)
|
||||||
|
|
||||||
result = DeltaActionsProcessorStep(enabled=False)(transition)
|
result = DeltaActionsProcessorStep(enabled=False)(transition)
|
||||||
torch.testing.assert_close(result[TransitionKey.ACTION], original)
|
torch.testing.assert_close(result[TransitionKey.ACTION], original)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Training batch shape validation ---
|
||||||
|
|
||||||
|
def test_delta_with_action_chunks(dataset, action_dim):
|
||||||
|
"""Verify delta works correctly with (B, chunk_size, action_dim) shaped actions."""
|
||||||
|
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||||
|
|
||||||
|
# Simulate a training batch: actions=(B, chunk_size, action_dim), state=(B, state_dim)
|
||||||
|
batch_actions = action_chunks[:4] # (4, chunk_size, action_dim)
|
||||||
|
batch_states = states[:4] # (4, state_dim)
|
||||||
|
|
||||||
|
mask = [True] * action_dim
|
||||||
|
delta = to_delta_actions(batch_actions, batch_states, mask)
|
||||||
|
|
||||||
|
# First action in each chunk should be close to zero (action[t] - state[t] ≈ small)
|
||||||
|
first_deltas = delta[:, 0, :] # (B, action_dim)
|
||||||
|
assert first_deltas.abs().mean() < delta.abs().mean(), (
|
||||||
|
f"First action in chunk should have smaller delta than average. "
|
||||||
|
f"First: {first_deltas.abs().mean():.4f}, Average: {delta.abs().mean():.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Later actions should have larger deltas
|
||||||
|
last_deltas = delta[:, -1, :] # (B, action_dim)
|
||||||
|
assert last_deltas.abs().mean() >= first_deltas.abs().mean(), (
|
||||||
|
f"Last action in chunk should have >= delta than first. "
|
||||||
|
f"Last: {last_deltas.abs().mean():.4f}, First: {first_deltas.abs().mean():.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Roundtrip
|
||||||
|
recovered = to_absolute_actions(delta, batch_states, mask)
|
||||||
|
torch.testing.assert_close(recovered, batch_actions)
|
||||||
|
|
||||||
|
|
||||||
|
def test_delta_stats_match_actual_data_distribution(dataset, action_dim):
|
||||||
|
"""Verify computed stats match the actual delta distribution."""
|
||||||
|
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||||
|
mask = [True] * action_dim
|
||||||
|
|
||||||
|
# Compute stats like the training script does
|
||||||
|
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||||
|
|
||||||
|
# Also compute directly
|
||||||
|
all_deltas = []
|
||||||
|
for actions, state in zip(action_chunks, states):
|
||||||
|
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||||
|
all_deltas.append(delta)
|
||||||
|
all_deltas_tensor = torch.cat(all_deltas, dim=0)
|
||||||
|
|
||||||
|
# Compare mean
|
||||||
|
actual_mean = all_deltas_tensor.mean(dim=0).numpy()
|
||||||
|
np.testing.assert_allclose(delta_stats["mean"], actual_mean, atol=0.01)
|
||||||
|
|
||||||
|
# Compare std
|
||||||
|
actual_std = all_deltas_tensor.std(dim=0).numpy()
|
||||||
|
np.testing.assert_allclose(delta_stats["std"], actual_std, atol=0.1)
|
||||||
|
|
||||||
|
# Verify q01 < mean < q99
|
||||||
|
assert (delta_stats["q01"] < delta_stats["mean"]).all(), "q01 should be < mean"
|
||||||
|
assert (delta_stats["mean"] < delta_stats["q99"]).all(), "mean should be < q99"
|
||||||
|
|
||||||
|
|
||||||
|
def test_quantile_normalization_roundtrip(dataset, action_dim):
|
||||||
|
"""Full roundtrip with QUANTILES normalization (what OpenPI uses for pi05)."""
|
||||||
|
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||||
|
mask = [True] * action_dim
|
||||||
|
|
||||||
|
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||||
|
stats = {ACTION: {k: v for k, v in delta_stats.items()}}
|
||||||
|
|
||||||
|
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
|
||||||
|
norm_map = {FeatureType.ACTION: NormalizationMode.QUANTILES}
|
||||||
|
|
||||||
|
delta_step = DeltaActionsProcessorStep(enabled=True)
|
||||||
|
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||||
|
absolute_step = AbsoluteActionsProcessorStep(enabled=True, delta_step=delta_step)
|
||||||
|
|
||||||
|
original_actions = action_chunks[0].unsqueeze(0)
|
||||||
|
state = states[0].unsqueeze(0)
|
||||||
|
|
||||||
|
batch = {ACTION: original_actions, OBS_STATE: state}
|
||||||
|
transition = batch_to_transition(batch)
|
||||||
|
|
||||||
|
# Forward: delta → quantile normalize
|
||||||
|
t1 = delta_step(transition)
|
||||||
|
t2 = normalizer(t1)
|
||||||
|
|
||||||
|
normalized = t2[TransitionKey.ACTION]
|
||||||
|
# Most values should be in [-1, 1] with quantile normalization
|
||||||
|
pct_in_range = (normalized.abs() < 2).float().mean()
|
||||||
|
assert pct_in_range > 0.5, (
|
||||||
|
f"Only {pct_in_range*100:.1f}% in [-2, 2] after quantile norm, expected >50%"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reverse: unnormalize → absolute
|
||||||
|
t3 = unnormalizer(t2)
|
||||||
|
t4 = absolute_step(t3)
|
||||||
|
|
||||||
|
recovered = t4[TransitionKey.ACTION]
|
||||||
|
torch.testing.assert_close(recovered, original_actions, atol=1e-3, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_not_modified_by_delta(dataset, action_dim):
|
||||||
|
"""State should never be modified by the delta processor."""
|
||||||
|
hf = dataset.hf_dataset
|
||||||
|
batch = {
|
||||||
|
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
|
||||||
|
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
|
||||||
|
}
|
||||||
|
original_state = batch[OBS_STATE].clone()
|
||||||
|
transition = batch_to_transition(batch)
|
||||||
|
|
||||||
|
step = DeltaActionsProcessorStep(enabled=True)
|
||||||
|
result = step(transition)
|
||||||
|
|
||||||
|
result_state = result[TransitionKey.OBSERVATION][OBS_STATE]
|
||||||
|
torch.testing.assert_close(result_state, original_state)
|
||||||
|
|||||||
Reference in New Issue
Block a user