mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
feat(record): Integrate RobotProcessor into recording loop and update policy handling
- Added support for RobotProcessor in the record_loop function to enhance data processing capabilities. - Updated the logic to reset both policy and processor when provided, ensuring proper state management. - Modified action prediction to utilize the processor, improving the overall functionality of the recording process. - Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities.
This commit is contained in:
committed by
Steven Palma
parent
670a278cbc
commit
4b24f94225
@@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
@@ -53,7 +53,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
@@ -35,7 +35,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor.pipeline import EnvTransition, RobotProcessor, TransitionIndex
|
||||
from lerobot.processor.pipeline import EnvTransition, IdentityProcessor, RobotProcessor, TransitionIndex
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
@@ -128,7 +128,9 @@ def make_processor(
|
||||
if pretrained_path:
|
||||
# Load a pretrained processor
|
||||
# TODO(azouitine): Handle this case.
|
||||
raise NotImplementedError("Loading a pretrained processor is not implemented.")
|
||||
return RobotProcessor.from_pretrained(source=pretrained_path), RobotProcessor(
|
||||
steps=[IdentityProcessor()], name="post_processor"
|
||||
)
|
||||
|
||||
# Create a new processor based on policy type
|
||||
if policy_cfg.type == "tdmpc":
|
||||
|
||||
+19
-8
@@ -74,8 +74,9 @@ from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.factory import make_policy, make_processor
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import RobotProcessor
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -195,6 +196,7 @@ def record_loop(
|
||||
dataset: LeRobotDataset | None = None,
|
||||
teleop: Teleoperator | list[Teleoperator] | None = None,
|
||||
policy: PreTrainedPolicy | None = None,
|
||||
processor: RobotProcessor | None = None,
|
||||
control_time_s: int | None = None,
|
||||
single_task: str | None = None,
|
||||
display_data: bool = False,
|
||||
@@ -219,9 +221,10 @@ def record_loop(
|
||||
"For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
|
||||
)
|
||||
|
||||
# if policy is given it needs cleaning up
|
||||
if policy is not None:
|
||||
# Reset policy and processor if they are provided
|
||||
if policy is not None or processor is not None:
|
||||
policy.reset()
|
||||
processor.reset()
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
@@ -237,12 +240,13 @@ def record_loop(
|
||||
if policy is not None or dataset is not None:
|
||||
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
||||
|
||||
if policy is not None:
|
||||
if policy is not None or processor is not None:
|
||||
action_values = predict_action(
|
||||
observation_frame,
|
||||
policy,
|
||||
get_safe_torch_device(policy.config.device),
|
||||
policy.config.use_amp,
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
processor=processor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
@@ -328,6 +332,13 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
# Load pretrained policy
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
processor = None
|
||||
if cfg.policy is not None:
|
||||
processor, _ = make_processor(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
if teleop is not None:
|
||||
|
||||
@@ -244,7 +244,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
|
||||
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
|
||||
@@ -31,6 +31,7 @@ from termcolor import colored
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import RobotProcessor
|
||||
from lerobot.robots import Robot
|
||||
|
||||
|
||||
@@ -101,6 +102,7 @@ def predict_action(
|
||||
observation: dict[str, np.ndarray],
|
||||
policy: PreTrainedPolicy,
|
||||
device: torch.device,
|
||||
processor: RobotProcessor,
|
||||
use_amp: bool,
|
||||
task: str | None = None,
|
||||
robot_type: str | None = None,
|
||||
@@ -122,6 +124,8 @@ def predict_action(
|
||||
observation["task"] = task if task else ""
|
||||
observation["robot_type"] = robot_type if robot_type else ""
|
||||
|
||||
observation = processor(observation)
|
||||
|
||||
# Compute the next action with the policy
|
||||
# based on the current observation
|
||||
action = policy.select_action(observation)
|
||||
|
||||
@@ -74,6 +74,7 @@ def save_checkpoint(
|
||||
policy: PreTrainedPolicy,
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None = None,
|
||||
preprocessor=None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
@@ -81,7 +82,9 @@ def save_checkpoint(
|
||||
├── pretrained_model/
|
||||
│ ├── config.json # policy config
|
||||
│ ├── model.safetensors # policy weights
|
||||
│ └── train_config.json # train config
|
||||
│ ├── train_config.json # train config
|
||||
│ ├── processor.json # processor config (if preprocessor provided)
|
||||
│ └── step_*.safetensors # processor state files (if any)
|
||||
└── training_state/
|
||||
├── optimizer_param_groups.json # optimizer param groups
|
||||
├── optimizer_state.safetensors # optimizer state
|
||||
@@ -95,10 +98,13 @@ def save_checkpoint(
|
||||
policy (PreTrainedPolicy): The policy to save.
|
||||
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
cfg.save_pretrained(pretrained_dir)
|
||||
if preprocessor is not None:
|
||||
preprocessor.save_pretrained(pretrained_dir)
|
||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user