diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 7516c7b47..89b6f1c18 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -134,6 +134,15 @@ class _NormalizationMixin: if self.dtype is None: self.dtype = torch.float32 self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) + self._reshape_visual_stats() + + def _reshape_visual_stats(self) -> None: + """Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting.""" + for key, feature in self.features.items(): + if feature.type == FeatureType.VISUAL and key in self._tensor_stats: + for stat_name, stat_tensor in self._tensor_stats[key].items(): + if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1: + self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1) def to( self, device: torch.device | str | None = None, dtype: torch.dtype | None = None @@ -152,6 +161,7 @@ class _NormalizationMixin: if dtype is not None: self.dtype = dtype self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) + self._reshape_visual_stats() return self def state_dict(self) -> dict[str, Tensor]: @@ -201,6 +211,7 @@ class _NormalizationMixin: # Don't load from state_dict, keep the explicitly provided stats # But ensure _tensor_stats is properly initialized self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment] + self._reshape_visual_stats() return # Normal behavior: load stats from state_dict @@ -211,6 +222,7 @@ class _NormalizationMixin: self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to( dtype=torch.float32, device=self.device ) + self._reshape_visual_stats() # Reconstruct the original stats dict from tensor stats for compatibility with to() method # and other functions that rely on self.stats diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 588adffac..6167456dc 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -60,7 +60,7 @@ from torch.multiprocessing import Event, Queue from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig -from lerobot.policies import make_policy +from lerobot.policies import make_policy, make_pre_post_processors from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.robots import so_follower # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401 @@ -261,6 +261,11 @@ def act_with_policy( policy = policy.eval() assert isinstance(policy, nn.Module) + preprocessor, _postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + dataset_stats=cfg.policy.dataset_stats, + ) + obs, info = online_env.reset() env_processor.reset() action_processor.reset() @@ -291,8 +296,8 @@ def act_with_policy( # Time policy inference and check if it meets FPS requirement with policy_timer: - # Extract observation from transition for policy - action = policy.select_action(batch=observation) + normalized_observation = preprocessor.process_observation(observation) + action = policy.select_action(batch=normalized_observation) policy_fps = policy_timer.fps_last log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index d1207421b..af910d314 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -70,7 +70,7 @@ from lerobot.common.wandb_utils import WandBLogger from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.datasets import LeRobotDataset, make_dataset -from lerobot.policies import make_policy +from lerobot.policies import make_policy, make_pre_post_processors from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.robots import so_follower # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401 @@ -317,6 +317,11 @@ def add_actor_information_and_train( policy.train() + preprocessor, _postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + dataset_stats=cfg.policy.dataset_stats, + ) + push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) last_time_policy_pushed = time.time() @@ -405,8 +410,8 @@ def add_actor_information_and_train( actions = batch[ACTION] rewards = batch["reward"] - observations = batch["state"] - next_observations = batch["next_state"] + observations = preprocessor.process_observation(batch["state"]) + next_observations = preprocessor.process_observation(batch["next_state"]) done = batch["done"] check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) @@ -463,8 +468,8 @@ def add_actor_information_and_train( actions = batch[ACTION] rewards = batch["reward"] - observations = batch["state"] - next_observations = batch["next_state"] + observations = preprocessor.process_observation(batch["state"]) + next_observations = preprocessor.process_observation(batch["next_state"]) done = batch["done"] check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)