mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
Refactored actor.py to use the pipeline
This commit is contained in:
@@ -83,6 +83,13 @@ from lerobot.utils.transition import (
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
)
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||
from lerobot.scripts.rl.gym_manipulator import (
|
||||
create_transition,
|
||||
make_processors,
|
||||
step_env_and_process_transition,
|
||||
)
|
||||
from lerobot.utils.utils import (
|
||||
TimerManager,
|
||||
get_safe_torch_device,
|
||||
@@ -236,7 +243,8 @@ def act_with_policy(
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(cfg=cfg.env)
|
||||
online_env, teleop_device = make_robot_env(cfg=cfg.env)
|
||||
env_processor, action_processor = make_processors(online_env, cfg.env, cfg.policy.device)
|
||||
|
||||
set_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
@@ -257,6 +265,13 @@ def act_with_policy(
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
||||
transition = env_processor(transition)
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
@@ -277,7 +292,9 @@ def act_with_policy(
|
||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
action = policy.select_action(batch=obs)
|
||||
# Extract observation from transition for policy
|
||||
batch_obs = transition[TransitionKey.OBSERVATION]
|
||||
action = policy.select_action(batch=batch_obs)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
@@ -285,34 +302,46 @@ def act_with_policy(
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
|
||||
# Use the new step function
|
||||
new_transition, terminate_episode = step_env_and_process_transition(
|
||||
env=online_env,
|
||||
transition=transition,
|
||||
action=action,
|
||||
teleop_device=teleop_device,
|
||||
env_processor=env_processor,
|
||||
action_processor=action_processor,
|
||||
)
|
||||
|
||||
# Extract values from processed transition
|
||||
reward = new_transition[TransitionKey.REWARD]
|
||||
done = new_transition.get(TransitionKey.DONE, False)
|
||||
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
|
||||
processed_action = new_transition[TransitionKey.ACTION]
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
# Increment total steps counter for intervention rate
|
||||
episode_total_steps += 1
|
||||
|
||||
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
|
||||
if "is_intervention" in info and info["is_intervention"]:
|
||||
# NOTE: The action space for demonstration before hand is with the full action space
|
||||
# but sometimes for example we want to deactivate the gripper
|
||||
action = info["action_intervention"]
|
||||
# Check for intervention from transition info
|
||||
intervention_info = new_transition[TransitionKey.INFO]
|
||||
if intervention_info.get("is_intervention", False):
|
||||
episode_intervention = True
|
||||
# Increment intervention steps counter
|
||||
episode_intervention_steps += 1
|
||||
|
||||
# Create transition for learner (convert to old format)
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=obs,
|
||||
action=action,
|
||||
state=transition[TransitionKey.OBSERVATION],
|
||||
action=processed_action,
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
next_state=new_transition[TransitionKey.OBSERVATION],
|
||||
done=done,
|
||||
truncated=truncated, # TODO: (azouitine) Handle truncation properly
|
||||
complementary_info=info,
|
||||
truncated=truncated,
|
||||
complementary_info=new_transition[TransitionKey.COMPLEMENTARY_DATA],
|
||||
)
|
||||
)
|
||||
# assign obs to the next obs and continue the rollout
|
||||
obs = next_obs
|
||||
|
||||
# Update transition for next iteration
|
||||
transition = new_transition
|
||||
|
||||
if done or truncated:
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
@@ -347,12 +376,21 @@ def act_with_policy(
|
||||
)
|
||||
)
|
||||
|
||||
# Reset intervention counters
|
||||
# Reset intervention counters and environment
|
||||
sum_reward_episode = 0.0
|
||||
episode_intervention = False
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
# Reset environment and processors
|
||||
obs, info = online_env.reset()
|
||||
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
||||
transition = env_processor(transition)
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
|
||||
@@ -404,6 +404,9 @@ class ImageCropResizeProcessor:
|
||||
if observation is None:
|
||||
return transition
|
||||
|
||||
if self.resize_size is None and not self.crop_params_dict:
|
||||
return transition
|
||||
|
||||
new_observation = dict(observation)
|
||||
|
||||
# Process all image keys in the observation
|
||||
@@ -777,6 +780,123 @@ def make_robot_env(cfg: EnvConfig) -> tuple[gym.Env, Any]:
|
||||
return env, teleop_device
|
||||
|
||||
|
||||
def make_processors(env, cfg, device=None):
    """
    Factory function to create environment and action processors.

    Args:
        env: The robot environment. It is handed to the processors that take
            an ``env`` argument (MotorCurrentProcessor, InverseKinematicsProcessor).
        cfg: Configuration object containing processor parameters. Reads
            ``cfg.fps``, ``cfg.device`` and the fields of ``cfg.processor``.
        device: Optional device given to DeviceProcessor. When None (the
            default), ``cfg.device`` is used, which keeps the two-argument
            call ``make_processors(env, cfg)`` working unchanged.

    Returns:
        tuple: (env_processor, action_processor)
    """
    # Callers such as the RL actor pass the device as a third positional
    # argument; only fall back to cfg.device when no override is supplied.
    target_device = cfg.device if device is None else device

    env_pipeline_steps = [
        ImageProcessor(),
        StateProcessor(),
        JointVelocityProcessor(dt=1.0 / cfg.fps),
        MotorCurrentProcessor(env=env),
        ImageCropResizeProcessor(
            crop_params_dict=cfg.processor.crop_params_dict,
            resize_size=cfg.processor.resize_size,
        ),
        # Episode length limit expressed in steps: seconds * steps-per-second.
        TimeLimitProcessor(max_episode_steps=int(cfg.processor.control_time_s * cfg.fps)),
        GripperPenaltyProcessor(
            penalty=cfg.processor.gripper_penalty,
            max_gripper_pos=cfg.processor.max_gripper_pos,
        ),
        DeviceProcessor(device=target_device),
    ]
    env_processor = RobotProcessor(steps=env_pipeline_steps)

    action_pipeline_steps = [
        InterventionActionProcessor(
            use_gripper=cfg.processor.use_gripper,
        ),
        InverseKinematicsProcessor(
            urdf_path=cfg.processor.urdf_path,
            target_frame_name=cfg.processor.target_frame_name,
            end_effector_step_sizes=cfg.processor.end_effector_step_sizes,
            end_effector_bounds=cfg.processor.end_effector_bounds,
            max_gripper_pos=cfg.processor.max_gripper_pos,
            env=env,
        ),
    ]
    action_processor = RobotProcessor(steps=action_pipeline_steps)

    return env_processor, action_processor
|
||||
|
||||
|
||||
def step_env_and_process_transition(
    env,
    transition,
    action,
    teleop_device,
    env_processor,
    action_processor,
):
    """
    Run a single environment step, routing the action and the resulting
    observation through their respective processor pipelines.

    Args:
        env: The robot environment; must expose ``step(action)``.
        transition: Current transition. It is copied, not modified.
        action: Proposed action for this step. It is written into the copied
            transition and passed through ``action_processor`` before being
            sent to the environment.
        teleop_device: Teleoperator device, queried via ``get_action()`` and
            ``get_teleop_events()``.
        env_processor: Pipeline applied to the transition built from the
            environment's step output.
        action_processor: Pipeline applied to the action transition
            (handles intervention).

    Returns:
        tuple: (new_transition, terminate_episode), where ``terminate_episode``
        is the DONE entry of the action-processed transition (False if absent).
    """
    # Query the teleoperator first: its action, then its event flags.
    teleop_action = teleop_device.get_action()
    teleop_events = teleop_device.get_teleop_events()

    # Shallow copy of the current transition carrying the proposed action.
    staged = dict(transition)
    staged[TransitionKey.ACTION] = action

    # Copy the complementary data so the caller's mapping is left untouched,
    # then attach the teleoperation signals for the action pipeline.
    staged_extras = staged.get(TransitionKey.COMPLEMENTARY_DATA, {}).copy()
    staged_extras["teleop_action"] = teleop_action
    staged_extras.update(teleop_events)
    staged[TransitionKey.COMPLEMENTARY_DATA] = staged_extras

    # Action pipeline (handles intervention).
    acted = action_processor(staged)
    processed_action = acted[TransitionKey.ACTION]
    terminate_episode = acted.get(TransitionKey.DONE, False)

    # Apply the processed action to the environment.
    obs, env_reward, terminated, truncated, info = env.step(processed_action)

    # Environment reward plus whatever reward the action pipeline carries.
    reward = env_reward + acted[TransitionKey.REWARD]

    # "raw_joint_positions" is moved out of info into complementary data;
    # entries coming from the action pipeline are merged on top of it.
    complementary_data = {
        "raw_joint_positions": info.pop("raw_joint_positions"),
        **acted[TransitionKey.COMPLEMENTARY_DATA],
    }
    info.update(acted[TransitionKey.INFO])

    # Assemble the raw transition and run it through the observation pipeline.
    new_transition = env_processor(
        create_transition(
            observation=obs,
            action=processed_action,
            reward=reward,
            done=terminated or terminate_episode,
            truncated=truncated,
            info=info,
            complementary_data=complementary_data,
        )
    )

    return new_transition, terminate_episode
|
||||
|
||||
|
||||
def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvConfig):
|
||||
dt = 1.0 / cfg.fps
|
||||
|
||||
@@ -841,54 +961,20 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
|
||||
while episode_idx < cfg.num_episodes:
|
||||
step_start_time = time.perf_counter()
|
||||
|
||||
# Get teleoperation action and extra signals
|
||||
teleop_action = teleop_device.get_action()
|
||||
teleop_events = teleop_device.get_teleop_events()
|
||||
|
||||
# Create a neutral action (no movement)
|
||||
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
||||
if hasattr(env, "use_gripper") and env.use_gripper:
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
||||
|
||||
# Create action transition
|
||||
action_transition = dict(transition)
|
||||
action_transition[TransitionKey.ACTION] = neutral_action
|
||||
|
||||
# Add teleoperation data to complementary data
|
||||
action_complementary_data = action_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).copy()
|
||||
action_complementary_data["teleop_action"] = teleop_action
|
||||
action_complementary_data.update(teleop_events)
|
||||
action_transition[TransitionKey.COMPLEMENTARY_DATA] = action_complementary_data
|
||||
|
||||
# Process action through action pipeline (handles intervention)
|
||||
processed_action_transition = action_processor(action_transition)
|
||||
|
||||
# Extract processed action and metadata
|
||||
processed_action = processed_action_transition[TransitionKey.ACTION]
|
||||
terminate_episode = processed_action_transition.get(TransitionKey.DONE, False)
|
||||
|
||||
# Step environment with processed action
|
||||
obs, reward, terminated, truncated, info = env.step(processed_action)
|
||||
|
||||
reward = reward + processed_action_transition[TransitionKey.REWARD]
|
||||
|
||||
# Process new observation
|
||||
complementary_data = {
|
||||
"raw_joint_positions": info.pop("raw_joint_positions"),
|
||||
**processed_action_transition[TransitionKey.COMPLEMENTARY_DATA],
|
||||
}
|
||||
info.update(processed_action_transition[TransitionKey.INFO])
|
||||
|
||||
transition = create_transition(
|
||||
observation=obs,
|
||||
action=processed_action,
|
||||
reward=reward,
|
||||
done=terminated or terminate_episode,
|
||||
truncated=truncated,
|
||||
info=info,
|
||||
complementary_data=complementary_data,
|
||||
# Use the new step function
|
||||
transition, terminate_episode = step_env_and_process_transition(
|
||||
env=env,
|
||||
transition=transition,
|
||||
action=neutral_action,
|
||||
teleop_device=teleop_device,
|
||||
env_processor=env_processor,
|
||||
action_processor=action_processor,
|
||||
)
|
||||
transition = env_processor(transition)
|
||||
terminated = transition.get(TransitionKey.DONE, False)
|
||||
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
||||
|
||||
@@ -963,38 +1049,7 @@ def replay_trajectory(env, action_processor, cfg):
|
||||
@parser.wrap()
|
||||
def main(cfg: EnvConfig):
|
||||
env, teleop_device = make_robot_env(cfg)
|
||||
env_pipeline_steps = [
|
||||
ImageProcessor(),
|
||||
StateProcessor(),
|
||||
JointVelocityProcessor(dt=1.0 / cfg.fps),
|
||||
MotorCurrentProcessor(env=env),
|
||||
ImageCropResizeProcessor(
|
||||
crop_params_dict=cfg.processor.crop_params_dict, resize_size=cfg.processor.resize_size
|
||||
),
|
||||
TimeLimitProcessor(max_episode_steps=int(cfg.processor.control_time_s * cfg.fps)),
|
||||
GripperPenaltyProcessor(
|
||||
penalty=cfg.processor.gripper_penalty, max_gripper_pos=cfg.processor.max_gripper_pos
|
||||
),
|
||||
DeviceProcessor(device=cfg.device),
|
||||
]
|
||||
|
||||
env_processor = RobotProcessor(steps=env_pipeline_steps)
|
||||
|
||||
action_pipeline_steps = [
|
||||
InterventionActionProcessor(
|
||||
use_gripper=cfg.processor.use_gripper,
|
||||
),
|
||||
InverseKinematicsProcessor(
|
||||
urdf_path=cfg.processor.urdf_path,
|
||||
target_frame_name=cfg.processor.target_frame_name,
|
||||
end_effector_step_sizes=cfg.processor.end_effector_step_sizes,
|
||||
end_effector_bounds=cfg.processor.end_effector_bounds,
|
||||
max_gripper_pos=cfg.processor.max_gripper_pos,
|
||||
env=env,
|
||||
),
|
||||
]
|
||||
|
||||
action_processor = RobotProcessor(steps=action_pipeline_steps)
|
||||
env_processor, action_processor = make_processors(env, cfg)
|
||||
|
||||
print("Environment observation space:", env.observation_space)
|
||||
print("Environment action space:", env.action_space)
|
||||
|
||||
Reference in New Issue
Block a user