mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
fix(rl): improve action processing for discrete and continuous actions
(cherry picked from commit f887ab3f6a)
This commit is contained in:
@@ -291,6 +291,18 @@ def act_with_policy(
|
|||||||
with policy_timer:
|
with policy_timer:
|
||||||
normalized_observation = preprocessor.process_observation(observation)
|
normalized_observation = preprocessor.process_observation(observation)
|
||||||
action = policy.select_action(batch=normalized_observation)
|
action = policy.select_action(batch=normalized_observation)
|
||||||
|
# Unnormalize only the continuous part. When `num_discrete_actions` is set,
|
||||||
|
# `select_action` concatenates an argmax index in env space at the last dim;
|
||||||
|
# action stats cover the continuous dims only, so feeding the full vector to
|
||||||
|
# the unnormalizer would shape-mismatch and would also corrupt the discrete
|
||||||
|
# index by treating it as a normalized value.
|
||||||
|
if cfg.policy.num_discrete_actions is not None:
|
||||||
|
continuous_action = postprocessor.process_action(action[..., :-1])
|
||||||
|
discrete_action = action[..., -1:].to(
|
||||||
|
device=continuous_action.device, dtype=continuous_action.dtype
|
||||||
|
)
|
||||||
|
action = torch.cat([continuous_action, discrete_action], dim=-1)
|
||||||
|
else:
|
||||||
action = postprocessor.process_action(action)
|
action = postprocessor.process_action(action)
|
||||||
policy_fps = policy_timer.fps_last
|
policy_fps = policy_timer.fps_last
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user