Refactored actor.py to use the pipeline

Michel Aractingi
2025-08-02 19:06:56 +02:00
parent e6e1edfd74
commit cfa672129e
2 changed files with 186 additions and 93 deletions
lerobot/scripts/rl/actor.py +57 -19
@@ -83,6 +83,13 @@ from lerobot.utils.transition import (
     move_state_dict_to_device,
     move_transition_to_device,
 )
+from lerobot.processor.pipeline import EnvTransition, TransitionKey
+from lerobot.scripts.rl.gym_manipulator import (
+    create_transition,
+    make_processors,
+    step_env_and_process_transition,
+)
 from lerobot.utils.utils import (
     TimerManager,
     get_safe_torch_device,
@@ -236,7 +243,8 @@ def act_with_policy(
     logging.info("make_env online")

-    online_env = make_robot_env(cfg=cfg.env)
+    online_env, teleop_device = make_robot_env(cfg=cfg.env)
+    env_processor, action_processor = make_processors(online_env, cfg.env, cfg.policy.device)

     set_seed(cfg.seed)
     device = get_safe_torch_device(cfg.policy.device, log=True)
@@ -257,6 +265,13 @@ def act_with_policy(
     assert isinstance(policy, nn.Module)

     obs, info = online_env.reset()
+    complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
+
+    env_processor.reset()
+    action_processor.reset()
+
+    # Process initial observation
+    transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
+    transition = env_processor(transition)

     # NOTE: For the moment we will solely handle the case of a single environment
     sum_reward_episode = 0
@@ -277,7 +292,9 @@ def act_with_policy(
         if interaction_step >= cfg.policy.online_step_before_learning:
             # Time policy inference and check if it meets FPS requirement
             with policy_timer:
-                action = policy.select_action(batch=obs)
+                # Extract observation from transition for policy
+                batch_obs = transition[TransitionKey.OBSERVATION]
+                action = policy.select_action(batch=batch_obs)
             policy_fps = policy_timer.fps_last

             log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
@@ -285,34 +302,46 @@ def act_with_policy(
         else:
             action = online_env.action_space.sample()

-        next_obs, reward, done, truncated, info = online_env.step(action)
+        # Use the new step function
+        new_transition, terminate_episode = step_env_and_process_transition(
+            env=online_env,
+            transition=transition,
+            action=action,
+            teleop_device=teleop_device,
+            env_processor=env_processor,
+            action_processor=action_processor,
+        )
+
+        # Extract values from processed transition
+        reward = new_transition[TransitionKey.REWARD]
+        done = new_transition.get(TransitionKey.DONE, False)
+        truncated = new_transition.get(TransitionKey.TRUNCATED, False)
+        processed_action = new_transition[TransitionKey.ACTION]

         sum_reward_episode += float(reward)
         # Increment total steps counter for intervention rate
         episode_total_steps += 1

-        # NOTE: We override the action if the intervention is True, because the action applied is the intervention action
-        if "is_intervention" in info and info["is_intervention"]:
-            # NOTE: The demonstration action space is the full action space,
-            # but sometimes, for example, we want to deactivate the gripper
-            action = info["action_intervention"]
+        # Check for intervention from transition info
+        intervention_info = new_transition[TransitionKey.INFO]
+        if intervention_info.get("is_intervention", False):
             episode_intervention = True
             # Increment intervention steps counter
             episode_intervention_steps += 1

+        # Create transition for learner (convert to old format)
         list_transition_to_send_to_learner.append(
             Transition(
-                state=obs,
-                action=action,
+                state=transition[TransitionKey.OBSERVATION],
+                action=processed_action,
                 reward=reward,
-                next_state=next_obs,
+                next_state=new_transition[TransitionKey.OBSERVATION],
                 done=done,
-                truncated=truncated,  # TODO: (azouitine) Handle truncation properly
-                complementary_info=info,
+                truncated=truncated,
+                complementary_info=new_transition[TransitionKey.COMPLEMENTARY_DATA],
             )
         )

-        # assign obs to the next obs and continue the rollout
-        obs = next_obs
+        # Update transition for next iteration
+        transition = new_transition

         if done or truncated:
             logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
@@ -347,12 +376,21 @@ def act_with_policy(
                 )
             )

-            # Reset intervention counters
+            # Reset intervention counters and environment
             sum_reward_episode = 0.0
             episode_intervention = False
             episode_intervention_steps = 0
             episode_total_steps = 0

+            # Reset environment and processors
+            obs, info = online_env.reset()
+            complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
+            env_processor.reset()
+            action_processor.reset()
+
+            # Process initial observation
+            transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
+            transition = env_processor(transition)

         if cfg.env.fps is not None:
             dt_time = time.perf_counter() - start_time
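
The learner still consumes the old tuple-style Transition, so the actor converts between the two formats inline (see the append above). create_transition itself is not shown in this diff; judging from its call sites, it builds a plain dict keyed by TransitionKey. A minimal sketch of the conversion under that assumption (the helper name and the import location of Transition are hypothetical):

from lerobot.processor.pipeline import TransitionKey
from lerobot.utils.transition import Transition  # assumed location of the old format

def to_learner_transition(prev_transition, new_transition, processed_action):
    # Hypothetical helper mirroring the inline conversion in actor.py:
    # the previous transition supplies the state, the new one everything else.
    return Transition(
        state=prev_transition[TransitionKey.OBSERVATION],
        action=processed_action,
        reward=new_transition[TransitionKey.REWARD],
        next_state=new_transition[TransitionKey.OBSERVATION],
        done=new_transition.get(TransitionKey.DONE, False),
        truncated=new_transition.get(TransitionKey.TRUNCATED, False),
        complementary_info=new_transition[TransitionKey.COMPLEMENTARY_DATA],
    )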
lerobot/scripts/rl/gym_manipulator.py +129 -74
@@ -404,6 +404,9 @@ class ImageCropResizeProcessor:
         if observation is None:
             return transition

+        if self.resize_size is None and not self.crop_params_dict:
+            return transition

         new_observation = dict(observation)

         # Process all image keys in the observation
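
Pipeline steps like ImageCropResizeProcessor appear to be plain callables that take a transition dict and return one, with an optional reset() hook (both are used elsewhere in this commit). A minimal sketch of that contract; this step is illustrative only, not part of the commit:

from lerobot.processor.pipeline import TransitionKey

class ClampRewardProcessor:
    # Illustrative step following the same contract: return early when there
    # is nothing to do, otherwise copy the transition and modify one key.
    def __init__(self, min_reward: float = -1.0, max_reward: float = 1.0):
        self.min_reward = min_reward
        self.max_reward = max_reward

    def __call__(self, transition):
        reward = transition.get(TransitionKey.REWARD)
        if reward is None:
            return transition
        new_transition = dict(transition)
        new_transition[TransitionKey.REWARD] = max(self.min_reward, min(self.max_reward, float(reward)))
        return new_transition

    def reset(self):
        # Called at episode boundaries; this step keeps no state.
        pass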
@@ -777,6 +780,123 @@ def make_robot_env(cfg: EnvConfig) -> tuple[gym.Env, Any]:
     return env, teleop_device


+def make_processors(env, cfg, device=None):
+    """
+    Factory function to create environment and action processors.
+
+    Args:
+        env: The robot environment
+        cfg: Configuration object containing processor parameters
+        device: Optional device override for the DeviceProcessor (defaults to cfg.device)
+
+    Returns:
+        tuple: (env_processor, action_processor)
+    """
+    env_pipeline_steps = [
+        ImageProcessor(),
+        StateProcessor(),
+        JointVelocityProcessor(dt=1.0 / cfg.fps),
+        MotorCurrentProcessor(env=env),
+        ImageCropResizeProcessor(
+            crop_params_dict=cfg.processor.crop_params_dict,
+            resize_size=cfg.processor.resize_size,
+        ),
+        TimeLimitProcessor(max_episode_steps=int(cfg.processor.control_time_s * cfg.fps)),
+        GripperPenaltyProcessor(
+            penalty=cfg.processor.gripper_penalty,
+            max_gripper_pos=cfg.processor.max_gripper_pos,
+        ),
+        DeviceProcessor(device=device if device is not None else cfg.device),
+    ]
+    env_processor = RobotProcessor(steps=env_pipeline_steps)
+
+    action_pipeline_steps = [
+        InterventionActionProcessor(
+            use_gripper=cfg.processor.use_gripper,
+        ),
+        InverseKinematicsProcessor(
+            urdf_path=cfg.processor.urdf_path,
+            target_frame_name=cfg.processor.target_frame_name,
+            end_effector_step_sizes=cfg.processor.end_effector_step_sizes,
+            end_effector_bounds=cfg.processor.end_effector_bounds,
+            max_gripper_pos=cfg.processor.max_gripper_pos,
+            env=env,
+        ),
+    ]
+    action_processor = RobotProcessor(steps=action_pipeline_steps)
+
+    return env_processor, action_processor
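
The two call sites in this commit differ: actor.py passes a third argument (cfg.policy.device), while main() below passes only env and cfg. The optional device parameter assumed in the signature above lets both work; a sketch:

# In actor.py: the policy device overrides cfg.device for DeviceProcessor
env_processor, action_processor = make_processors(online_env, cfg.env, cfg.policy.device)

# In gym_manipulator.py's main(): fall back to cfg.device
env_processor, action_processor = make_processors(env, cfg)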
+def step_env_and_process_transition(
+    env,
+    transition,
+    action,
+    teleop_device,
+    env_processor,
+    action_processor,
+):
+    """
+    Execute one environment step, with the processors handling intervention and observation processing.
+
+    Args:
+        env: The robot environment
+        transition: Current transition state
+        action: Action to execute (replaced by a neutral action in gym_manipulator mode)
+        teleop_device: Teleoperator device providing intervention signals
+        env_processor: Environment processor for observations
+        action_processor: Action processor for handling interventions
+
+    Returns:
+        tuple: (new_transition, terminate_episode)
+    """
+    # Get teleoperation action and events
+    teleop_action = teleop_device.get_action()
+    teleop_events = teleop_device.get_teleop_events()
+
+    # Create action transition
+    action_transition = dict(transition)
+    action_transition[TransitionKey.ACTION] = action
+
+    # Add teleoperation data to complementary data
+    action_complementary_data = action_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).copy()
+    action_complementary_data["teleop_action"] = teleop_action
+    action_complementary_data.update(teleop_events)
+    action_transition[TransitionKey.COMPLEMENTARY_DATA] = action_complementary_data
+
+    # Process action through action pipeline (handles intervention)
+    processed_action_transition = action_processor(action_transition)
+
+    # Extract processed action and metadata
+    processed_action = processed_action_transition[TransitionKey.ACTION]
+    terminate_episode = processed_action_transition.get(TransitionKey.DONE, False)
+
+    # Step environment with processed action
+    obs, reward, terminated, truncated, info = env.step(processed_action)
+
+    # Combine rewards from environment and action processor
+    reward = reward + processed_action_transition[TransitionKey.REWARD]
+
+    # Process new observation
+    complementary_data = {
+        "raw_joint_positions": info.pop("raw_joint_positions"),
+        **processed_action_transition[TransitionKey.COMPLEMENTARY_DATA],
+    }
+    info.update(processed_action_transition[TransitionKey.INFO])
+
+    new_transition = create_transition(
+        observation=obs,
+        action=processed_action,
+        reward=reward,
+        done=terminated or terminate_episode,
+        truncated=truncated,
+        info=info,
+        complementary_data=complementary_data,
+    )
+    new_transition = env_processor(new_transition)
+
+    return new_transition, terminate_episode
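
Taken together, make_processors and step_env_and_process_transition reduce a rollout iteration to the pattern below, condensed from the actor.py hunk in this commit (variable names as in that file):

obs, info = online_env.reset()
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
env_processor.reset()
action_processor.reset()
transition = env_processor(create_transition(observation=obs, info=info, complementary_data=complementary_data))

while True:
    # The policy acts on the processed observation carried by the transition
    action = policy.select_action(batch=transition[TransitionKey.OBSERVATION])
    transition, terminate_episode = step_env_and_process_transition(
        env=online_env,
        transition=transition,
        action=action,
        teleop_device=teleop_device,
        env_processor=env_processor,
        action_processor=action_processor,
    )
    if transition.get(TransitionKey.DONE, False) or transition.get(TransitionKey.TRUNCATED, False):
        break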

 def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvConfig):
     dt = 1.0 / cfg.fps
@@ -841,54 +961,20 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
     while episode_idx < cfg.num_episodes:
         step_start_time = time.perf_counter()

-        # Get teleoperation action and extra signals
-        teleop_action = teleop_device.get_action()
-        teleop_events = teleop_device.get_teleop_events()

         # Create a neutral action (no movement)
         neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
         if hasattr(env, "use_gripper") and env.use_gripper:
             neutral_action = torch.cat([neutral_action, torch.tensor([1.0])])  # Gripper: stay

-        # Create action transition
-        action_transition = dict(transition)
-        action_transition[TransitionKey.ACTION] = neutral_action
-
-        # Add teleoperation data to complementary data
-        action_complementary_data = action_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).copy()
-        action_complementary_data["teleop_action"] = teleop_action
-        action_complementary_data.update(teleop_events)
-        action_transition[TransitionKey.COMPLEMENTARY_DATA] = action_complementary_data
-
-        # Process action through action pipeline (handles intervention)
-        processed_action_transition = action_processor(action_transition)
-
-        # Extract processed action and metadata
-        processed_action = processed_action_transition[TransitionKey.ACTION]
-        terminate_episode = processed_action_transition.get(TransitionKey.DONE, False)
-
-        # Step environment with processed action
-        obs, reward, terminated, truncated, info = env.step(processed_action)
-        reward = reward + processed_action_transition[TransitionKey.REWARD]
-
-        # Process new observation
-        complementary_data = {
-            "raw_joint_positions": info.pop("raw_joint_positions"),
-            **processed_action_transition[TransitionKey.COMPLEMENTARY_DATA],
-        }
-        info.update(processed_action_transition[TransitionKey.INFO])
-
-        transition = create_transition(
-            observation=obs,
-            action=processed_action,
-            reward=reward,
-            done=terminated or terminate_episode,
-            truncated=truncated,
-            info=info,
-            complementary_data=complementary_data,
-        )
-        transition = env_processor(transition)
+        # Use the new step function
+        transition, terminate_episode = step_env_and_process_transition(
+            env=env,
+            transition=transition,
+            action=neutral_action,
+            teleop_device=teleop_device,
+            env_processor=env_processor,
+            action_processor=action_processor,
+        )

         terminated = transition.get(TransitionKey.DONE, False)
         truncated = transition.get(TransitionKey.TRUNCATED, False)
@@ -963,38 +1049,7 @@ def replay_trajectory(env, action_processor, cfg):
 @parser.wrap()
 def main(cfg: EnvConfig):
     env, teleop_device = make_robot_env(cfg)

-    env_pipeline_steps = [
-        ImageProcessor(),
-        StateProcessor(),
-        JointVelocityProcessor(dt=1.0 / cfg.fps),
-        MotorCurrentProcessor(env=env),
-        ImageCropResizeProcessor(
-            crop_params_dict=cfg.processor.crop_params_dict, resize_size=cfg.processor.resize_size
-        ),
-        TimeLimitProcessor(max_episode_steps=int(cfg.processor.control_time_s * cfg.fps)),
-        GripperPenaltyProcessor(
-            penalty=cfg.processor.gripper_penalty, max_gripper_pos=cfg.processor.max_gripper_pos
-        ),
-        DeviceProcessor(device=cfg.device),
-    ]
-    env_processor = RobotProcessor(steps=env_pipeline_steps)
-
-    action_pipeline_steps = [
-        InterventionActionProcessor(
-            use_gripper=cfg.processor.use_gripper,
-        ),
-        InverseKinematicsProcessor(
-            urdf_path=cfg.processor.urdf_path,
-            target_frame_name=cfg.processor.target_frame_name,
-            end_effector_step_sizes=cfg.processor.end_effector_step_sizes,
-            end_effector_bounds=cfg.processor.end_effector_bounds,
-            max_gripper_pos=cfg.processor.max_gripper_pos,
-            env=env,
-        ),
-    ]
-    action_processor = RobotProcessor(steps=action_pipeline_steps)
+    env_processor, action_processor = make_processors(env, cfg)

     print("Environment observation space:", env.observation_space)
     print("Environment action space:", env.action_space)