mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
fix(rl): improve action processing for discrete and continuous actions
(cherry picked from commit f887ab3f6a)
This commit is contained in:
+13
-1
@@ -291,7 +291,19 @@ def act_with_policy(
|
||||
with policy_timer:
|
||||
normalized_observation = preprocessor.process_observation(observation)
|
||||
action = policy.select_action(batch=normalized_observation)
|
||||
action = postprocessor.process_action(action)
|
||||
# Unnormalize only the continuous part. When `num_discrete_actions` is set,
|
||||
# `select_action` concatenates an argmax index in env space at the last dim;
|
||||
# action stats cover the continuous dims only, so feeding the full vector to
|
||||
# the unnormalizer would shape-mismatch and would also corrupt the discrete
|
||||
# index by treating it as a normalized value.
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
continuous_action = postprocessor.process_action(action[..., :-1])
|
||||
discrete_action = action[..., -1:].to(
|
||||
device=continuous_action.device, dtype=continuous_action.dtype
|
||||
)
|
||||
action = torch.cat([continuous_action, discrete_action], dim=-1)
|
||||
else:
|
||||
action = postprocessor.process_action(action)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
Reference in New Issue
Block a user