mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
chore(rollout): remove speculative action order fix
This commit is contained in:
@@ -225,66 +225,6 @@ class TestRolloutInferenceRelativeActions:
|
||||
torch.testing.assert_close(action_2, torch.from_numpy(state_1))
|
||||
assert policy.predict_calls == 1
|
||||
|
||||
def test_sync_engine_restores_action_names_for_relative_exclusions(self):
|
||||
"""Serialized processors may omit action names; sync rollout must still honor gripper exclusions."""
|
||||
action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM - 1)] + ["gripper.pos"]
|
||||
relative_step, preprocessor, postprocessor = _make_relative_sync_pipelines(
|
||||
ACTION_DIM,
|
||||
action_names=action_names,
|
||||
exclude_joints=["gripper"],
|
||||
)
|
||||
relative_step.action_names = None
|
||||
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
|
||||
policy.config.action_feature_names = action_names
|
||||
dataset_features = {ACTION: {"names": action_names}}
|
||||
|
||||
assert relative_step.action_names is None
|
||||
|
||||
engine = SyncInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset_features=dataset_features,
|
||||
ordered_action_keys=action_names,
|
||||
task="test",
|
||||
device="cpu",
|
||||
robot_type="mock",
|
||||
)
|
||||
|
||||
state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
|
||||
action = engine.get_action({OBS_STATE: state.copy()})
|
||||
expected = torch.from_numpy(state.copy())
|
||||
expected[-1] = 0.0
|
||||
|
||||
torch.testing.assert_close(action, expected)
|
||||
assert relative_step.action_names == action_names
|
||||
|
||||
def test_sync_engine_does_not_remap_chunk_through_dataset_action_names(self):
|
||||
"""Postprocessed chunks are already in policy order; dataset feature order must not scramble them."""
|
||||
action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
|
||||
_, preprocessor, postprocessor = _make_relative_sync_pipelines(
|
||||
ACTION_DIM,
|
||||
action_names=action_names,
|
||||
)
|
||||
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
|
||||
policy.config.action_feature_names = action_names
|
||||
|
||||
engine = SyncInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset_features={ACTION: {"names": list(reversed(action_names))}},
|
||||
ordered_action_keys=action_names,
|
||||
task="test",
|
||||
device="cpu",
|
||||
robot_type="mock",
|
||||
)
|
||||
|
||||
state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
|
||||
action = engine.get_action({OBS_STATE: state.copy()})
|
||||
|
||||
torch.testing.assert_close(action, torch.from_numpy(state))
|
||||
|
||||
def test_rtc_reanchoring_prefers_raw_cached_state(self):
|
||||
"""RTC re-anchoring must use the raw state cached before observation normalization."""
|
||||
action_dim = ACTION_DIM
|
||||
|
||||
Reference in New Issue
Block a user