RL works at this commit - fixed actor.py and bugs in gym_manipulator
@@ -243,7 +243,7 @@ def act_with_policy(
     logging.info("make_env online")
 
     online_env, teleop_device = make_robot_env(cfg=cfg.env)
-    env_processor, action_processor = make_processors(online_env, cfg.env, cfg.policy.device)
+    env_processor, action_processor = make_processors(online_env, cfg.env)
 
     set_seed(cfg.seed)
     device = get_safe_torch_device(cfg.policy.device, log=True)
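
Note: this hunk drops the device argument from make_processors; as the @@ -800 hunk below shows, device placement is now handled inside the env pipeline by a DeviceProcessor step that reads cfg.device. A minimal sketch of the new wiring, with stand-in classes (only the names make_processors, DeviceProcessor, RobotProcessor, and cfg.device come from this diff):

class DeviceProcessor:
    """Stand-in: the real step moves transition tensors onto a device."""

    def __init__(self, device: str):
        self.device = device


class RobotProcessor:
    """Stand-in: the real class runs its steps over each transition."""

    def __init__(self, steps: list):
        self.steps = steps


class EnvConfig:
    device = "cuda"  # assumed field; the diff reads cfg.device inside make_processors


def make_processors(env, cfg):
    # Device placement now lives inside the pipeline, not in the signature.
    env_pipeline_steps = []  # ...crop/resize, time limit, optional gripper penalty...
    env_pipeline_steps.append(DeviceProcessor(device=cfg.device))
    env_processor = RobotProcessor(steps=env_pipeline_steps)
    action_processor = RobotProcessor(steps=[])  # assembled separately in the real code
    return env_processor, action_processor


env_processor, action_processor = make_processors(env=None, cfg=EnvConfig())
print(env_processor.steps[-1].device)  # cuda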
@@ -288,18 +288,15 @@ def act_with_policy(
             logging.info("[ACTOR] Shutting down act_with_policy")
             return
 
         if interaction_step >= cfg.policy.online_step_before_learning:
-            # Time policy inference and check if it meets FPS requirement
-            with policy_timer:
-                # Extract observation from transition for policy
-                batch_obs = transition[TransitionKey.OBSERVATION]
-                action = policy.select_action(batch=batch_obs)
-            policy_fps = policy_timer.fps_last
+            observation = transition[TransitionKey.OBSERVATION]
 
-            log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
+            # Time policy inference and check if it meets FPS requirement
+            with policy_timer:
+                # Extract observation from transition for policy
+                action = policy.select_action(batch=observation)
+            policy_fps = policy_timer.fps_last
 
         else:
             action = online_env.action_space.sample()
+        log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
 
         # Use the new step function
         new_transition, terminate_episode = step_env_and_process_transition(
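
Note: the observation is now extracted from the transition once, outside the timed block, so policy_timer brackets only policy.select_action and policy_fps measures inference alone rather than inference plus dict plumbing. A runnable sketch of the pattern, with an assumed stand-in for the timer (the real policy_timer is only known to expose fps_last):

import time


class FpsTimer:
    """Assumed stand-in for policy_timer: times a with-block, exposes fps_last."""

    def __enter__(self):
        self._t0 = time.perf_counter()
        return self

    def __exit__(self, *exc):
        elapsed = time.perf_counter() - self._t0
        self.fps_last = 1.0 / elapsed if elapsed > 0 else float("inf")


def select_action(batch):
    return [sum(batch) / len(batch)]  # placeholder for policy.select_action


policy_timer = FpsTimer()
observation = [0.1, 0.2, 0.3]  # pulled from the transition *before* timing starts
with policy_timer:
    action = select_action(batch=observation)  # only inference is inside the timer
print(f"policy fps: {policy_timer.fps_last:.1f}")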
@@ -312,10 +309,11 @@ def act_with_policy(
         )
 
         # Extract values from processed transition
         next_observation = new_transition[TransitionKey.OBSERVATION]
+        executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
         reward = new_transition[TransitionKey.REWARD]
         done = new_transition.get(TransitionKey.DONE, False)
         truncated = new_transition.get(TransitionKey.TRUNCATED, False)
         processed_action = new_transition[TransitionKey.ACTION]
 
         sum_reward_episode += float(reward)
         episode_total_steps += 1
@@ -329,13 +327,13 @@ def act_with_policy(
             # Create transition for learner (convert to old format)
             list_transition_to_send_to_learner.append(
                 Transition(
-                    state=transition[TransitionKey.OBSERVATION],
-                    action=processed_action,
+                    state=observation,
+                    action=executed_action,
                     reward=reward,
-                    next_state=new_transition[TransitionKey.OBSERVATION],
+                    next_state=next_observation,
                     done=done,
                     truncated=truncated,
-                    complementary_info=new_transition[TransitionKey.COMPLEMENTARY_DATA],
+                    complementary_info={},  # new_transition[TransitionKey.COMPLEMENTARY_DATA],
                 )
             )
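
Note: together with the executed_action line added in the previous hunk, the learner transition now carries the action that was actually executed on the robot (the teleop action, which may include a human override) instead of the policy's processed action, and complementary_info is sent empty, presumably to trim the payload. A toy reconstruction of the resulting transition (the Transition field names are from this diff; all values are illustrative):

from typing import Any, NamedTuple


class Transition(NamedTuple):
    state: Any
    action: Any
    reward: float
    next_state: Any
    done: bool
    truncated: bool
    complementary_info: dict


policy_action = [0.10, 0.00]  # what the policy proposed
teleop_action = [0.00, 0.25]  # what actually ran on the robot (human override)

t = Transition(
    state={"observation.state": [0.0]},
    action=teleop_action,            # executed action, not the policy proposal
    reward=1.0,
    next_state={"observation.state": [0.1]},
    done=False,
    truncated=False,
    complementary_info={},           # intentionally emptied in this commit
)
print(t.action)  # [0.0, 0.25]

Learning from the executed action is what lets human corrections flow into the replay buffer in human-in-the-loop training.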
@@ -800,11 +800,15 @@ def make_processors(env, cfg):
             crop_params_dict=cfg.processor.crop_params_dict, resize_size=cfg.processor.resize_size
         ),
         TimeLimitProcessor(max_episode_steps=int(cfg.processor.control_time_s * cfg.fps)),
-        GripperPenaltyProcessor(
-            penalty=cfg.processor.gripper_penalty, max_gripper_pos=cfg.processor.max_gripper_pos
-        ),
-        DeviceProcessor(device=cfg.device),
     ]
+    if cfg.processor.use_gripper:
+        env_pipeline_steps.append(
+            GripperPenaltyProcessor(
+                penalty=cfg.processor.gripper_penalty, max_gripper_pos=cfg.processor.max_gripper_pos
+            )
+        )
+    env_pipeline_steps.append(DeviceProcessor(device=cfg.device))
 
     env_processor = RobotProcessor(steps=env_pipeline_steps)
 
     action_pipeline_steps = [
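
Note: GripperPenaltyProcessor is now appended only when cfg.processor.use_gripper is set, and DeviceProcessor is appended last, presumably so the single device transfer happens after all optional steps. A toy version of this conditional-pipeline pattern with stand-in step classes:

from dataclasses import dataclass


@dataclass
class Step:
    """Stand-in for a processor step; real steps transform transitions."""

    name: str


def build_env_pipeline(use_gripper: bool, device: str) -> list:
    steps = [Step("crop_resize"), Step("time_limit")]
    if use_gripper:  # optional step, gated by config
        steps.append(Step("gripper_penalty"))
    steps.append(Step(f"device:{device}"))  # device transfer always runs last
    return steps


print([s.name for s in build_env_pipeline(use_gripper=False, device="cuda")])
# ['crop_resize', 'time_limit', 'device:cuda']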
@@ -920,12 +924,13 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
         "action": action_features,
         "next.reward": {"dtype": "float32", "shape": (1,), "names": None},
         "next.done": {"dtype": "bool", "shape": (1,), "names": None},
-        "complementary_info.discrete_penalty": {
+    }
+    if cfg.processor.use_gripper:
+        features["complementary_info.discrete_penalty"] = {
             "dtype": "float32",
             "shape": (1,),
             "names": ["discrete_penalty"],
-        },
-    }
+        }
 
     for key, value in transition[TransitionKey.OBSERVATION].items():
         if key == "observation.state":
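
Note: the discrete_penalty feature is now declared only when the gripper is in use, mirroring the conditional frame writing in the @@ -977 hunk below; a schema that declares keys no frame provides (or frames with undeclared keys) is a typical source of recording errors. A small self-contained sketch of keeping the schema and the frames under one flag:

import numpy as np

use_gripper = False  # stands in for cfg.processor.use_gripper

# Declare the schema and build the frame under the same flag so the dataset's
# declared features always match the keys present in each recorded frame.
features = {
    "next.reward": {"dtype": "float32", "shape": (1,), "names": None},
    "next.done": {"dtype": "bool", "shape": (1,), "names": None},
}
if use_gripper:
    features["complementary_info.discrete_penalty"] = {
        "dtype": "float32",
        "shape": (1,),
        "names": ["discrete_penalty"],
    }

frame = {
    "next.reward": np.array([1.0], dtype=np.float32),
    "next.done": np.array([False], dtype=bool),
}
if use_gripper:
    frame["complementary_info.discrete_penalty"] = np.array([0.0], dtype=np.float32)

assert set(frame) == set(features)  # schema and frames stay in lockstep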
@@ -977,16 +982,17 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
         truncated = transition.get(TransitionKey.TRUNCATED, False)
 
         if cfg.mode == "record":
-            observations = {k: v.squeeze(0) for k, v in transition[TransitionKey.OBSERVATION].items()}
+            observations = {k: v.squeeze(0).cpu() for k, v in transition[TransitionKey.OBSERVATION].items()}
            frame = {
                 **observations,
-                "action": transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"],
+                "action": transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"].cpu(),
                 "next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
                 "next.done": np.array([terminated or truncated], dtype=bool),
-                "complementary_info.discrete_penalty": np.array(
-                    [transition[TransitionKey.COMPLEMENTARY_DATA]["discrete_penalty"]], dtype=np.float32
-                ),
             }
+            if cfg.processor.use_gripper:
+                frame["complementary_info.discrete_penalty"] = np.array(
+                    [transition[TransitionKey.COMPLEMENTARY_DATA]["discrete_penalty"]], dtype=np.float32
+                )
             dataset.add_frame(frame, task=cfg.task)
 
         episode_step += 1
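
Note: because the env pipeline now ends with DeviceProcessor, recorded tensors may live on the GPU, so this hunk moves them back to CPU (.squeeze(0).cpu()) before dataset.add_frame. A pure-PyTorch sketch of that guard (to_cpu_frame is a hypothetical helper, not a lerobot function):

import torch


def to_cpu_frame(observation: dict) -> dict:
    """Hypothetical helper mirroring the .squeeze(0).cpu() pattern in this hunk.

    Drops the batch dimension and moves every tensor to CPU, since dataset
    writers generally cannot consume CUDA tensors directly.
    """
    return {k: v.squeeze(0).cpu() for k, v in observation.items()}


obs = {"observation.state": torch.zeros(1, 6)}  # batched, possibly on the GPU
frame = to_cpu_frame(obs)
assert frame["observation.state"].device.type == "cpu"
assert frame["observation.state"].shape == (6,)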
@@ -997,16 +1003,6 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
             logging.info(
                 f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
             )
 
-            # Reset for new episode
-            obs, info = env.reset()
-            complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
-            env_processor.reset()
-            action_processor.reset()
-
-            transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
-            transition = env_processor(transition)
-
             episode_step = 0
             episode_idx += 1
@@ -1019,6 +1015,15 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
                 logging.info(f"Saving episode {episode_idx}")
                 dataset.save_episode()
 
+            # Reset for new episode
+            obs, info = env.reset()
+            complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
+            env_processor.reset()
+            action_processor.reset()
+
+            transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
+            transition = env_processor(transition)
+
         # Maintain fps timing
         busy_wait(dt - (time.perf_counter() - step_start_time))
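
Note: the last two hunks relocate the reset block from the episode-end logging site to after dataset.save_episode(), so a finished episode is flushed before the environment and processors are re-initialized; resetting first could overwrite the final transition before it is saved. A runnable toy of the save-then-reset ordering (Recorder and Env are stand-ins):

class Recorder:
    """Stand-in for the dataset: counts saved episodes."""

    def __init__(self):
        self.saved = 0

    def save_episode(self):
        self.saved += 1


class Env:
    """Stand-in environment with a reset()."""

    def reset(self):
        return {"obs": 0}, {"info": None}


dataset, env = Recorder(), Env()
done, mode = True, "record"
if done:
    if mode == "record":
        dataset.save_episode()  # flush the finished episode first...
    obs, info = env.reset()     # ...then re-initialize for the next one
print(dataset.saved)  # 1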