feat(record): Integrate RobotProcessor into recording loop and update policy handling

- Added a processor argument to record_loop so a RobotProcessor can transform observations before action prediction.
- Reset both the policy and the processor at the start of recording when either is provided.
- Updated predict_action to accept the processor and apply it to the observation before the policy selects an action.
- Extended save_checkpoint to save the preprocessor's configuration and state alongside the policy.
Author: AdilZouitine
Date: 2025-07-17 13:10:15 +02:00
Committed by: Steven Palma
Commit: 4b24f94225 (parent 670a278cbc)
6 changed files with 36 additions and 14 deletions
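Taken together, the diffs below insert a RobotProcessor between the raw observation and the policy at recording time, and persist it at checkpoint time. A sketch of the resulting setup (cfg and dataset stand in for the RecordConfig and LeRobotDataset used by the record script; only the factory calls are taken from the diffs):

from lerobot.policies.factory import make_policy, make_processor


def setup_policy_and_processor(cfg, dataset):
    # Mirrors the record.py hunk below: build the policy, then a matching
    # preprocessor; make_processor returns a (pre, post) pair.
    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
    processor = None
    if cfg.policy is not None:
        processor, _ = make_processor(
            policy_cfg=cfg.policy,
            pretrained_path=cfg.policy.pretrained_path,
            dataset_stats=dataset.meta.stats,
        )
    return policy, processor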
+1 -2
@@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download
 from huggingface_hub.constants import CONFIG_NAME
 from huggingface_hub.errors import HfHubHTTPError
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.configs.types import FeatureType, PolicyFeature
 from lerobot.constants import ACTION, OBS_STATE
 from lerobot.optim.optimizers import OptimizerConfig
 from lerobot.optim.schedulers import LRSchedulerConfig
@@ -53,7 +53,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
     """

     n_obs_steps: int = 1
-    normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
     input_features: dict[str, PolicyFeature] = field(default_factory=dict)
     output_features: dict[str, PolicyFeature] = field(default_factory=dict)
+4 -2
@@ -35,7 +35,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
 from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
 from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
-from lerobot.processor.pipeline import EnvTransition, RobotProcessor, TransitionIndex
+from lerobot.processor.pipeline import EnvTransition, IdentityProcessor, RobotProcessor, TransitionIndex


 def get_policy_class(name: str) -> PreTrainedPolicy:
@@ -128,7 +128,9 @@ def make_processor(
     if pretrained_path:
         # Load a pretrained processor
         # TODO(azouitine): Handle this case.
-        raise NotImplementedError("Loading a pretrained processor is not implemented.")
+        return RobotProcessor.from_pretrained(source=pretrained_path), RobotProcessor(
+            steps=[IdentityProcessor()], name="post_processor"
+        )

     # Create a new processor based on policy type
     if policy_cfg.type == "tdmpc":
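The pretrained branch above now returns a (preprocessor, postprocessor) pair. A save/load round trip for a RobotProcessor, as a sketch (the constructor and from_pretrained call are taken from this hunk, save_pretrained from the save_checkpoint diff below; the path is illustrative):

from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor

pre = RobotProcessor(steps=[IdentityProcessor()], name="pre_processor")
pre.save_pretrained("outputs/example/pretrained_model")  # assumed to write processor.json + state files

loaded_pre = RobotProcessor.from_pretrained(source="outputs/example/pretrained_model")
post = RobotProcessor(steps=[IdentityProcessor()], name="post_processor")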
+19 -8
@@ -74,8 +74,9 @@ from lerobot.datasets.image_writer import safe_stop_image_writer
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
 from lerobot.datasets.video_utils import VideoEncodingManager
-from lerobot.policies.factory import make_policy
+from lerobot.policies.factory import make_policy, make_processor
 from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.processor import RobotProcessor
 from lerobot.robots import (  # noqa: F401
     Robot,
     RobotConfig,
@@ -195,6 +196,7 @@ def record_loop(
     dataset: LeRobotDataset | None = None,
     teleop: Teleoperator | list[Teleoperator] | None = None,
     policy: PreTrainedPolicy | None = None,
+    processor: RobotProcessor | None = None,
     control_time_s: int | None = None,
     single_task: str | None = None,
     display_data: bool = False,
@@ -219,9 +221,10 @@ def record_loop(
             "For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
         )

-    # if policy is given it needs cleaning up
-    if policy is not None:
+    # Reset policy and processor if they are provided
+    if policy is not None or processor is not None:
         policy.reset()
+        processor.reset()

     timestamp = 0
     start_episode_t = time.perf_counter()
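Note that, as diffed, the guard is shared but the bodies are not: passing only a processor would still call policy.reset() on None, and passing only a policy would do the reverse. A None-safe variant of the same reset logic, as a sketch:

if policy is not None:
    policy.reset()
if processor is not None:
    processor.reset()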
@@ -237,12 +240,13 @@ def record_loop(
     if policy is not None or dataset is not None:
         observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")

-    if policy is not None:
+    if policy is not None or processor is not None:
         action_values = predict_action(
-            observation_frame,
-            policy,
-            get_safe_torch_device(policy.config.device),
-            policy.config.use_amp,
+            observation=observation_frame,
+            policy=policy,
+            device=get_safe_torch_device(policy.config.device),
+            processor=processor,
+            use_amp=policy.config.use_amp,
             task=single_task,
             robot_type=robot.robot_type,
         )
@@ -328,6 +332,13 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
     # Load pretrained policy
     policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
+    processor = None
+    if cfg.policy is not None:
+        processor, _ = make_processor(
+            policy_cfg=cfg.policy,
+            pretrained_path=cfg.policy.pretrained_path,
+            dataset_stats=dataset.meta.stats,
+        )

     robot.connect()

     if teleop is not None:
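With record() now building the processor, the call into the loop might look roughly like this (a sketch: the policy and processor keywords come from the signature hunk above; robot, events, fps, and the cfg fields are assumptions about the surrounding script, not shown in this diff):

record_loop(
    robot=robot,
    events=events,
    fps=cfg.dataset.fps,
    dataset=dataset,
    policy=policy,
    processor=processor,
    control_time_s=cfg.dataset.episode_time_s,
    single_task=cfg.dataset.single_task,
    display_data=cfg.display_data,
)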
+1 -1
@@ -244,7 +244,7 @@ def train(cfg: TrainPipelineConfig):
         if cfg.save_checkpoint and is_saving_step:
             logging.info(f"Checkpoint policy after step {step}")
             checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
-            save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
+            save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor)
             update_last_checkpoint(checkpoint_dir)
             if wandb_logger:
                 wandb_logger.log_policy(checkpoint_dir)
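None of this commit's hunks show where preprocessor is defined in train(); presumably it is built earlier from the same factory. A sketch under that assumption (signature taken from the factory diff above):

from lerobot.policies.factory import make_processor

preprocessor, postprocessor = make_processor(
    policy_cfg=cfg.policy,
    pretrained_path=cfg.policy.pretrained_path,
    dataset_stats=dataset.meta.stats,
)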
+4
@@ -31,6 +31,7 @@ from termcolor import colored

 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.datasets.utils import DEFAULT_FEATURES
 from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.processor import RobotProcessor
 from lerobot.robots import Robot
@@ -101,6 +102,7 @@ def predict_action(
     observation: dict[str, np.ndarray],
     policy: PreTrainedPolicy,
     device: torch.device,
+    processor: RobotProcessor,
     use_amp: bool,
     task: str | None = None,
     robot_type: str | None = None,
@@ -122,6 +124,8 @@ def predict_action(
         observation["task"] = task if task else ""
         observation["robot_type"] = robot_type if robot_type else ""

+        observation = processor(observation)
+
         # Compute the next action with the policy
         # based on the current observation
         action = policy.select_action(observation)
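In this function the processor is invoked as a plain callable on the observation dict, so anything mapping a dict to a dict satisfies the interface used here. A minimal hypothetical stand-in (normalize_state and the keys are illustrative, not lerobot API):

import numpy as np


def normalize_state(observation: dict) -> dict:
    # Hypothetical stand-in for `processor(observation)`: dict in, dict out.
    obs = dict(observation)
    state = obs.get("observation.state")
    if isinstance(state, np.ndarray):
        obs["observation.state"] = (state - state.mean()) / (state.std() + 1e-8)
    return obs


sample = {"observation.state": np.array([0.1, 0.4, 0.2]), "task": "pick cube"}
sample = normalize_state(sample)  # same call shape as the line added above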
+7 -1
@@ -74,6 +74,7 @@ def save_checkpoint(
     policy: PreTrainedPolicy,
     optimizer: Optimizer,
     scheduler: LRScheduler | None = None,
+    preprocessor=None,
 ) -> None:
     """This function creates the following directory structure:
@@ -81,7 +82,9 @@ def save_checkpoint(
         pretrained_model/
             config.json  # policy config
             model.safetensors  # policy weights
-            train_config.json  # train config
+            train_config.json  # train config
+            processor.json  # processor config (if preprocessor provided)
+            step_*.safetensors  # processor state files (if any)
         training_state/
             optimizer_param_groups.json  # optimizer param groups
             optimizer_state.safetensors  # optimizer state
@@ -95,10 +98,13 @@ def save_checkpoint(
         policy (PreTrainedPolicy): The policy to save.
         optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
         scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
+        preprocessor: The preprocessor/pipeline to save. Defaults to None.
     """
     pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
     policy.save_pretrained(pretrained_dir)
     cfg.save_pretrained(pretrained_dir)
+    if preprocessor is not None:
+        preprocessor.save_pretrained(pretrained_dir)

     save_training_state(checkpoint_dir, step, optimizer, scheduler)
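After a save, the processor files sit next to the policy under pretrained_model/, so reloading could look like this (a sketch; the checkpoint path is illustrative, and the from_pretrained call mirrors the factory hunk above):

from lerobot.processor.pipeline import RobotProcessor

pretrained_dir = "outputs/train/run/checkpoints/0005000/pretrained_model"
processor = RobotProcessor.from_pretrained(source=pretrained_dir)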