mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
feat(record): Integrate RobotProcessor into recording loop and update policy handling
- Added support for RobotProcessor in the record_loop function to enhance data processing capabilities.
- Updated the logic to reset both policy and processor when provided, ensuring proper state management.
- Modified action prediction to utilize the processor, improving the overall functionality of the recording process.
- Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities.
This commit is contained in:
committed by
Steven Palma
parent
670a278cbc
commit
4b24f94225
@@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download
|
|||||||
from huggingface_hub.constants import CONFIG_NAME
|
from huggingface_hub.constants import CONFIG_NAME
|
||||||
from huggingface_hub.errors import HfHubHTTPError
|
from huggingface_hub.errors import HfHubHTTPError
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
from lerobot.constants import ACTION, OBS_STATE
|
from lerobot.constants import ACTION, OBS_STATE
|
||||||
from lerobot.optim.optimizers import OptimizerConfig
|
from lerobot.optim.optimizers import OptimizerConfig
|
||||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||||
@@ -53,7 +53,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
n_obs_steps: int = 1
|
n_obs_steps: int = 1
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
|
|
||||||
|
|
||||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
|
|||||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
from lerobot.processor.pipeline import EnvTransition, RobotProcessor, TransitionIndex
|
from lerobot.processor.pipeline import EnvTransition, IdentityProcessor, RobotProcessor, TransitionIndex
|
||||||
|
|
||||||
|
|
||||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||||
@@ -128,7 +128,9 @@ def make_processor(
|
|||||||
if pretrained_path:
|
if pretrained_path:
|
||||||
# Load a pretrained processor
|
# Load a pretrained processor
|
||||||
# TODO(azouitine): Handle this case.
|
# TODO(azouitine): Handle this case.
|
||||||
raise NotImplementedError("Loading a pretrained processor is not implemented.")
|
return RobotProcessor.from_pretrained(source=pretrained_path), RobotProcessor(
|
||||||
|
steps=[IdentityProcessor()], name="post_processor"
|
||||||
|
)
|
||||||
|
|
||||||
# Create a new processor based on policy type
|
# Create a new processor based on policy type
|
||||||
if policy_cfg.type == "tdmpc":
|
if policy_cfg.type == "tdmpc":
|
||||||
|
|||||||
+19
-8
@@ -74,8 +74,9 @@ from lerobot.datasets.image_writer import safe_stop_image_writer
|
|||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||||
from lerobot.policies.factory import make_policy
|
from lerobot.policies.factory import make_policy, make_processor
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.processor import RobotProcessor
|
||||||
from lerobot.robots import ( # noqa: F401
|
from lerobot.robots import ( # noqa: F401
|
||||||
Robot,
|
Robot,
|
||||||
RobotConfig,
|
RobotConfig,
|
||||||
@@ -195,6 +196,7 @@ def record_loop(
|
|||||||
dataset: LeRobotDataset | None = None,
|
dataset: LeRobotDataset | None = None,
|
||||||
teleop: Teleoperator | list[Teleoperator] | None = None,
|
teleop: Teleoperator | list[Teleoperator] | None = None,
|
||||||
policy: PreTrainedPolicy | None = None,
|
policy: PreTrainedPolicy | None = None,
|
||||||
|
processor: RobotProcessor | None = None,
|
||||||
control_time_s: int | None = None,
|
control_time_s: int | None = None,
|
||||||
single_task: str | None = None,
|
single_task: str | None = None,
|
||||||
display_data: bool = False,
|
display_data: bool = False,
|
||||||
@@ -219,9 +221,10 @@ def record_loop(
|
|||||||
"For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
|
"For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
|
||||||
)
|
)
|
||||||
|
|
||||||
# if policy is given it needs cleaning up
|
# Reset policy and processor if they are provided
|
||||||
if policy is not None:
|
if policy is not None or processor is not None:
|
||||||
policy.reset()
|
policy.reset()
|
||||||
|
processor.reset()
|
||||||
|
|
||||||
timestamp = 0
|
timestamp = 0
|
||||||
start_episode_t = time.perf_counter()
|
start_episode_t = time.perf_counter()
|
||||||
@@ -237,12 +240,13 @@ def record_loop(
|
|||||||
if policy is not None or dataset is not None:
|
if policy is not None or dataset is not None:
|
||||||
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
||||||
|
|
||||||
if policy is not None:
|
if policy is not None or processor is not None:
|
||||||
action_values = predict_action(
|
action_values = predict_action(
|
||||||
observation_frame,
|
observation=observation_frame,
|
||||||
policy,
|
policy=policy,
|
||||||
get_safe_torch_device(policy.config.device),
|
device=get_safe_torch_device(policy.config.device),
|
||||||
policy.config.use_amp,
|
processor=processor,
|
||||||
|
use_amp=policy.config.use_amp,
|
||||||
task=single_task,
|
task=single_task,
|
||||||
robot_type=robot.robot_type,
|
robot_type=robot.robot_type,
|
||||||
)
|
)
|
||||||
@@ -328,6 +332,13 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
|||||||
|
|
||||||
# Load pretrained policy
|
# Load pretrained policy
|
||||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||||
|
processor = None
|
||||||
|
if cfg.policy is not None:
|
||||||
|
processor, _ = make_processor(
|
||||||
|
policy_cfg=cfg.policy,
|
||||||
|
pretrained_path=cfg.policy.pretrained_path,
|
||||||
|
dataset_stats=dataset.meta.stats,
|
||||||
|
)
|
||||||
|
|
||||||
robot.connect()
|
robot.connect()
|
||||||
if teleop is not None:
|
if teleop is not None:
|
||||||
|
|||||||
@@ -244,7 +244,7 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
if cfg.save_checkpoint and is_saving_step:
|
if cfg.save_checkpoint and is_saving_step:
|
||||||
logging.info(f"Checkpoint policy after step {step}")
|
logging.info(f"Checkpoint policy after step {step}")
|
||||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||||
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
|
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor)
|
||||||
update_last_checkpoint(checkpoint_dir)
|
update_last_checkpoint(checkpoint_dir)
|
||||||
if wandb_logger:
|
if wandb_logger:
|
||||||
wandb_logger.log_policy(checkpoint_dir)
|
wandb_logger.log_policy(checkpoint_dir)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from termcolor import colored
|
|||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
from lerobot.datasets.utils import DEFAULT_FEATURES
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.processor import RobotProcessor
|
||||||
from lerobot.robots import Robot
|
from lerobot.robots import Robot
|
||||||
|
|
||||||
|
|
||||||
@@ -101,6 +102,7 @@ def predict_action(
|
|||||||
observation: dict[str, np.ndarray],
|
observation: dict[str, np.ndarray],
|
||||||
policy: PreTrainedPolicy,
|
policy: PreTrainedPolicy,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
processor: RobotProcessor,
|
||||||
use_amp: bool,
|
use_amp: bool,
|
||||||
task: str | None = None,
|
task: str | None = None,
|
||||||
robot_type: str | None = None,
|
robot_type: str | None = None,
|
||||||
@@ -122,6 +124,8 @@ def predict_action(
|
|||||||
observation["task"] = task if task else ""
|
observation["task"] = task if task else ""
|
||||||
observation["robot_type"] = robot_type if robot_type else ""
|
observation["robot_type"] = robot_type if robot_type else ""
|
||||||
|
|
||||||
|
observation = processor(observation)
|
||||||
|
|
||||||
# Compute the next action with the policy
|
# Compute the next action with the policy
|
||||||
# based on the current observation
|
# based on the current observation
|
||||||
action = policy.select_action(observation)
|
action = policy.select_action(observation)
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ def save_checkpoint(
|
|||||||
policy: PreTrainedPolicy,
|
policy: PreTrainedPolicy,
|
||||||
optimizer: Optimizer,
|
optimizer: Optimizer,
|
||||||
scheduler: LRScheduler | None = None,
|
scheduler: LRScheduler | None = None,
|
||||||
|
preprocessor=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""This function creates the following directory structure:
|
"""This function creates the following directory structure:
|
||||||
|
|
||||||
@@ -81,7 +82,9 @@ def save_checkpoint(
|
|||||||
├── pretrained_model/
|
├── pretrained_model/
|
||||||
│ ├── config.json # policy config
|
│ ├── config.json # policy config
|
||||||
│ ├── model.safetensors # policy weights
|
│ ├── model.safetensors # policy weights
|
||||||
│ └── train_config.json # train config
|
│ ├── train_config.json # train config
|
||||||
|
│ ├── processor.json # processor config (if preprocessor provided)
|
||||||
|
│ └── step_*.safetensors # processor state files (if any)
|
||||||
└── training_state/
|
└── training_state/
|
||||||
├── optimizer_param_groups.json # optimizer param groups
|
├── optimizer_param_groups.json # optimizer param groups
|
||||||
├── optimizer_state.safetensors # optimizer state
|
├── optimizer_state.safetensors # optimizer state
|
||||||
@@ -95,10 +98,13 @@ def save_checkpoint(
|
|||||||
policy (PreTrainedPolicy): The policy to save.
|
policy (PreTrainedPolicy): The policy to save.
|
||||||
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
||||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||||
|
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||||
"""
|
"""
|
||||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||||
policy.save_pretrained(pretrained_dir)
|
policy.save_pretrained(pretrained_dir)
|
||||||
cfg.save_pretrained(pretrained_dir)
|
cfg.save_pretrained(pretrained_dir)
|
||||||
|
if preprocessor is not None:
|
||||||
|
preprocessor.save_pretrained(pretrained_dir)
|
||||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user