feat(record): Integrate RobotProcessor into recording loop and update policy handling

- Added support for RobotProcessor in the record_loop function to enhance data processing capabilities.
- Updated the logic to reset both policy and processor when provided, ensuring proper state management.
- Modified action prediction to utilize the processor, improving the overall functionality of the recording process.
- Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities.
This commit is contained in:
AdilZouitine
2025-07-17 13:10:15 +02:00
committed by Steven Palma
parent 670a278cbc
commit 4b24f94225
6 changed files with 36 additions and 14 deletions
+1 -2
View File
@@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError from huggingface_hub.errors import HfHubHTTPError
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE from lerobot.constants import ACTION, OBS_STATE
from lerobot.optim.optimizers import OptimizerConfig from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig from lerobot.optim.schedulers import LRSchedulerConfig
@@ -53,7 +53,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
""" """
n_obs_steps: int = 1 n_obs_steps: int = 1
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
input_features: dict[str, PolicyFeature] = field(default_factory=dict) input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict) output_features: dict[str, PolicyFeature] = field(default_factory=dict)
+4 -2
View File
@@ -35,7 +35,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor.pipeline import EnvTransition, RobotProcessor, TransitionIndex from lerobot.processor.pipeline import EnvTransition, IdentityProcessor, RobotProcessor, TransitionIndex
def get_policy_class(name: str) -> PreTrainedPolicy: def get_policy_class(name: str) -> PreTrainedPolicy:
@@ -128,7 +128,9 @@ def make_processor(
if pretrained_path: if pretrained_path:
# Load a pretrained processor # Load a pretrained processor
# TODO(azouitine): Handle this case. # TODO(azouitine): Handle this case.
raise NotImplementedError("Loading a pretrained processor is not implemented.") return RobotProcessor.from_pretrained(source=pretrained_path), RobotProcessor(
steps=[IdentityProcessor()], name="post_processor"
)
# Create a new processor based on policy type # Create a new processor based on policy type
if policy_cfg.type == "tdmpc": if policy_cfg.type == "tdmpc":
+19 -8
View File
@@ -74,8 +74,9 @@ from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.datasets.video_utils import VideoEncodingManager from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy from lerobot.policies.factory import make_policy, make_processor
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import RobotProcessor
from lerobot.robots import ( # noqa: F401 from lerobot.robots import ( # noqa: F401
Robot, Robot,
RobotConfig, RobotConfig,
@@ -195,6 +196,7 @@ def record_loop(
dataset: LeRobotDataset | None = None, dataset: LeRobotDataset | None = None,
teleop: Teleoperator | list[Teleoperator] | None = None, teleop: Teleoperator | list[Teleoperator] | None = None,
policy: PreTrainedPolicy | None = None, policy: PreTrainedPolicy | None = None,
processor: RobotProcessor | None = None,
control_time_s: int | None = None, control_time_s: int | None = None,
single_task: str | None = None, single_task: str | None = None,
display_data: bool = False, display_data: bool = False,
@@ -219,9 +221,10 @@ def record_loop(
"For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot." "For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
) )
# if policy is given it needs cleaning up # Reset policy and processor if they are provided
if policy is not None: if policy is not None or processor is not None:
policy.reset() policy.reset()
processor.reset()
timestamp = 0 timestamp = 0
start_episode_t = time.perf_counter() start_episode_t = time.perf_counter()
@@ -237,12 +240,13 @@ def record_loop(
if policy is not None or dataset is not None: if policy is not None or dataset is not None:
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
if policy is not None: if policy is not None or processor is not None:
action_values = predict_action( action_values = predict_action(
observation_frame, observation=observation_frame,
policy, policy=policy,
get_safe_torch_device(policy.config.device), device=get_safe_torch_device(policy.config.device),
policy.config.use_amp, processor=processor,
use_amp=policy.config.use_amp,
task=single_task, task=single_task,
robot_type=robot.robot_type, robot_type=robot.robot_type,
) )
@@ -328,6 +332,13 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
# Load pretrained policy # Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
processor = None
if cfg.policy is not None:
processor, _ = make_processor(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=dataset.meta.stats,
)
robot.connect() robot.connect()
if teleop is not None: if teleop is not None:
+1 -1
View File
@@ -244,7 +244,7 @@ def train(cfg: TrainPipelineConfig):
if cfg.save_checkpoint and is_saving_step: if cfg.save_checkpoint and is_saving_step:
logging.info(f"Checkpoint policy after step {step}") logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler) save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor)
update_last_checkpoint(checkpoint_dir) update_last_checkpoint(checkpoint_dir)
if wandb_logger: if wandb_logger:
wandb_logger.log_policy(checkpoint_dir) wandb_logger.log_policy(checkpoint_dir)
+4
View File
@@ -31,6 +31,7 @@ from termcolor import colored
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_FEATURES from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import RobotProcessor
from lerobot.robots import Robot from lerobot.robots import Robot
@@ -101,6 +102,7 @@ def predict_action(
observation: dict[str, np.ndarray], observation: dict[str, np.ndarray],
policy: PreTrainedPolicy, policy: PreTrainedPolicy,
device: torch.device, device: torch.device,
processor: RobotProcessor,
use_amp: bool, use_amp: bool,
task: str | None = None, task: str | None = None,
robot_type: str | None = None, robot_type: str | None = None,
@@ -122,6 +124,8 @@ def predict_action(
observation["task"] = task if task else "" observation["task"] = task if task else ""
observation["robot_type"] = robot_type if robot_type else "" observation["robot_type"] = robot_type if robot_type else ""
observation = processor(observation)
# Compute the next action with the policy # Compute the next action with the policy
# based on the current observation # based on the current observation
action = policy.select_action(observation) action = policy.select_action(observation)
+7 -1
View File
@@ -74,6 +74,7 @@ def save_checkpoint(
policy: PreTrainedPolicy, policy: PreTrainedPolicy,
optimizer: Optimizer, optimizer: Optimizer,
scheduler: LRScheduler | None = None, scheduler: LRScheduler | None = None,
preprocessor=None,
) -> None: ) -> None:
"""This function creates the following directory structure: """This function creates the following directory structure:
@@ -81,7 +82,9 @@ def save_checkpoint(
pretrained_model/ pretrained_model/
config.json # policy config config.json # policy config
model.safetensors # policy weights model.safetensors # policy weights
train_config.json # train config train_config.json # train config
processor.json # processor config (if preprocessor provided)
step_*.safetensors # processor state files (if any)
training_state/ training_state/
optimizer_param_groups.json # optimizer param groups optimizer_param_groups.json # optimizer param groups
optimizer_state.safetensors # optimizer state optimizer_state.safetensors # optimizer state
@@ -95,10 +98,13 @@ def save_checkpoint(
policy (PreTrainedPolicy): The policy to save. policy (PreTrainedPolicy): The policy to save.
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None. optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None. scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
preprocessor: The preprocessor/pipeline to save. Defaults to None.
""" """
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir) policy.save_pretrained(pretrained_dir)
cfg.save_pretrained(pretrained_dir) cfg.save_pretrained(pretrained_dir)
if preprocessor is not None:
preprocessor.save_pretrained(pretrained_dir)
save_training_state(checkpoint_dir, step, optimizer, scheduler) save_training_state(checkpoint_dir, step, optimizer, scheduler)