fix(rl): record pre-step observation so (obs, action, next.reward) align in gym_manipulator dataset

This commit is contained in:
Khalil Meftah
2026-05-04 19:33:45 +02:00
parent d4a568ee6c
commit 0d60a855be
+7 -6
View File
@@ -698,6 +698,12 @@ def control_loop(
if use_gripper:
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
observation = {
k: v.squeeze(0).cpu()
for k, v in transition[TransitionKey.OBSERVATION].items()
if isinstance(v, torch.Tensor)
}
transition = step_env_and_process_transition(
env=env,
transition=transition,
@@ -709,16 +715,11 @@ def control_loop(
truncated = transition.get(TransitionKey.TRUNCATED, False)
if cfg.mode == "record":
observations = {
k: v.squeeze(0).cpu()
for k, v in transition[TransitionKey.OBSERVATION].items()
if isinstance(v, torch.Tensor)
}
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
"teleop_action", transition[TransitionKey.ACTION]
)
frame = {
**observations,
**observation,
ACTION: action_to_record.cpu(),
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
DONE: np.array([terminated or truncated], dtype=bool),