add processor to main

This commit is contained in:
Khalil Meftah
2026-04-24 17:06:57 +02:00
parent 580d818aa9
commit 6495bb9706
3 changed files with 30 additions and 8 deletions
@@ -134,6 +134,15 @@ class _NormalizationMixin:
if self.dtype is None:
self.dtype = torch.float32
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
def _reshape_visual_stats(self) -> None:
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
for key, feature in self.features.items():
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
for stat_name, stat_tensor in self._tensor_stats[key].items():
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
@@ -152,6 +161,7 @@ class _NormalizationMixin:
if dtype is not None:
self.dtype = dtype
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
return self
def state_dict(self) -> dict[str, Tensor]:
@@ -201,6 +211,7 @@ class _NormalizationMixin:
# Don't load from state_dict, keep the explicitly provided stats
# But ensure _tensor_stats is properly initialized
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
self._reshape_visual_stats()
return
# Normal behavior: load stats from state_dict
@@ -211,6 +222,7 @@ class _NormalizationMixin:
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
dtype=torch.float32, device=self.device
)
self._reshape_visual_stats()
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
# and other functions that rely on self.stats
+8 -3
View File
@@ -60,7 +60,7 @@ from torch.multiprocessing import Event, Queue
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies import make_policy
from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
@@ -261,6 +261,11 @@ def act_with_policy(
policy = policy.eval()
assert isinstance(policy, nn.Module)
preprocessor, _postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
dataset_stats=cfg.policy.dataset_stats,
)
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
@@ -291,8 +296,8 @@ def act_with_policy(
# Time policy inference and check if it meets FPS requirement
with policy_timer:
# Extract observation from transition for policy
action = policy.select_action(batch=observation)
normalized_observation = preprocessor.process_observation(observation)
action = policy.select_action(batch=normalized_observation)
policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
+10 -5
View File
@@ -70,7 +70,7 @@ from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets import LeRobotDataset, make_dataset
from lerobot.policies import make_policy
from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
@@ -317,6 +317,11 @@ def add_actor_information_and_train(
policy.train()
preprocessor, _postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
dataset_stats=cfg.policy.dataset_stats,
)
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
last_time_policy_pushed = time.time()
@@ -405,8 +410,8 @@ def add_actor_information_and_train(
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
observations = preprocessor.process_observation(batch["state"])
next_observations = preprocessor.process_observation(batch["next_state"])
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
@@ -463,8 +468,8 @@ def add_actor_information_and_train(
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
observations = preprocessor.process_observation(batch["state"])
next_observations = preprocessor.process_observation(batch["next_state"])
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)