diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 2fe5fa827..718cb7436 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -291,7 +291,19 @@ def act_with_policy( with policy_timer: normalized_observation = preprocessor.process_observation(observation) action = policy.select_action(batch=normalized_observation) - action = postprocessor.process_action(action) + # Unnormalize only the continuous part. When `num_discrete_actions` is set, + # `select_action` concatenates an argmax index in env space at the last dim; + # action stats cover the continuous dims only, so feeding the full vector to + # the unnormalizer would shape-mismatch and would also corrupt the discrete + # index by treating it as a normalized value. + if cfg.policy.num_discrete_actions is not None: + continuous_action = postprocessor.process_action(action[..., :-1]) + discrete_action = action[..., -1:].to( + device=continuous_action.device, dtype=continuous_action.dtype + ) + action = torch.cat([continuous_action, discrete_action], dim=-1) + else: + action = postprocessor.process_action(action) policy_fps = policy_timer.fps_last log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)