fix(rl): improve action processing for discrete and continuous actions

(cherry picked from commit f887ab3f6a)
This commit is contained in:
Khalil Meftah
2026-04-26 22:47:52 +02:00
parent c278cfa026
commit f3993cbbb1
+13 -1
View File
@@ -291,7 +291,19 @@ def act_with_policy(
with policy_timer:
normalized_observation = preprocessor.process_observation(observation)
action = policy.select_action(batch=normalized_observation)
action = postprocessor.process_action(action)
# Unnormalize only the continuous part. When `num_discrete_actions` is set,
# `select_action` concatenates an argmax index in env space at the last dim;
# action stats cover the continuous dims only, so feeding the full vector to
# the unnormalizer would shape-mismatch and would also corrupt the discrete
# index by treating it as a normalized value.
if cfg.policy.num_discrete_actions is not None:
continuous_action = postprocessor.process_action(action[..., :-1])
discrete_action = action[..., -1:].to(
device=continuous_action.device, dtype=continuous_action.dtype
)
action = torch.cat([continuous_action, discrete_action], dim=-1)
else:
action = postprocessor.process_action(action)
policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)