mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
fix(rl): record pre-step observation so (obs, action, next.reward) align in gym_manipulator dataset
This commit is contained in:
@@ -698,6 +698,12 @@ def control_loop(
|
|||||||
if use_gripper:
|
if use_gripper:
|
||||||
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
k: v.squeeze(0).cpu()
|
||||||
|
for k, v in transition[TransitionKey.OBSERVATION].items()
|
||||||
|
if isinstance(v, torch.Tensor)
|
||||||
|
}
|
||||||
|
|
||||||
transition = step_env_and_process_transition(
|
transition = step_env_and_process_transition(
|
||||||
env=env,
|
env=env,
|
||||||
transition=transition,
|
transition=transition,
|
||||||
@@ -709,16 +715,11 @@ def control_loop(
|
|||||||
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
||||||
|
|
||||||
if cfg.mode == "record":
|
if cfg.mode == "record":
|
||||||
observations = {
|
|
||||||
k: v.squeeze(0).cpu()
|
|
||||||
for k, v in transition[TransitionKey.OBSERVATION].items()
|
|
||||||
if isinstance(v, torch.Tensor)
|
|
||||||
}
|
|
||||||
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
|
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
|
||||||
"teleop_action", transition[TransitionKey.ACTION]
|
"teleop_action", transition[TransitionKey.ACTION]
|
||||||
)
|
)
|
||||||
frame = {
|
frame = {
|
||||||
**observations,
|
**observation,
|
||||||
ACTION: action_to_record.cpu(),
|
ACTION: action_to_record.cpu(),
|
||||||
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
||||||
DONE: np.array([terminated or truncated], dtype=bool),
|
DONE: np.array([terminated or truncated], dtype=bool),
|
||||||
|
|||||||
Reference in New Issue
Block a user