mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix(rl): record pre-step observation so (obs, action, next.reward) align in gym_manipulator dataset
This commit is contained in:
@@ -698,6 +698,12 @@ def control_loop(
|
||||
if use_gripper:
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
||||
|
||||
observation = {
|
||||
k: v.squeeze(0).cpu()
|
||||
for k, v in transition[TransitionKey.OBSERVATION].items()
|
||||
if isinstance(v, torch.Tensor)
|
||||
}
|
||||
|
||||
transition = step_env_and_process_transition(
|
||||
env=env,
|
||||
transition=transition,
|
||||
@@ -709,16 +715,11 @@ def control_loop(
|
||||
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
||||
|
||||
if cfg.mode == "record":
|
||||
observations = {
|
||||
k: v.squeeze(0).cpu()
|
||||
for k, v in transition[TransitionKey.OBSERVATION].items()
|
||||
if isinstance(v, torch.Tensor)
|
||||
}
|
||||
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
|
||||
"teleop_action", transition[TransitionKey.ACTION]
|
||||
)
|
||||
frame = {
|
||||
**observations,
|
||||
**observation,
|
||||
ACTION: action_to_record.cpu(),
|
||||
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
||||
DONE: np.array([terminated or truncated], dtype=bool),
|
||||
|
||||
Reference in New Issue
Block a user