mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
RL works at this commit - fixed actor.py and bugs in gym_manipulator
This commit is contained in:
@@ -243,7 +243,7 @@ def act_with_policy(
|
|||||||
logging.info("make_env online")
|
logging.info("make_env online")
|
||||||
|
|
||||||
online_env, teleop_device = make_robot_env(cfg=cfg.env)
|
online_env, teleop_device = make_robot_env(cfg=cfg.env)
|
||||||
env_processor, action_processor = make_processors(online_env, cfg.env, cfg.policy.device)
|
env_processor, action_processor = make_processors(online_env, cfg.env)
|
||||||
|
|
||||||
set_seed(cfg.seed)
|
set_seed(cfg.seed)
|
||||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||||
@@ -288,18 +288,15 @@ def act_with_policy(
|
|||||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||||
return
|
return
|
||||||
|
|
||||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
observation = transition[TransitionKey.OBSERVATION]
|
||||||
# Time policy inference and check if it meets FPS requirement
|
|
||||||
with policy_timer:
|
|
||||||
# Extract observation from transition for policy
|
|
||||||
batch_obs = transition[TransitionKey.OBSERVATION]
|
|
||||||
action = policy.select_action(batch=batch_obs)
|
|
||||||
policy_fps = policy_timer.fps_last
|
|
||||||
|
|
||||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
# Time policy inference and check if it meets FPS requirement
|
||||||
|
with policy_timer:
|
||||||
|
# Extract observation from transition for policy
|
||||||
|
action = policy.select_action(batch=observation)
|
||||||
|
policy_fps = policy_timer.fps_last
|
||||||
|
|
||||||
else:
|
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||||
action = online_env.action_space.sample()
|
|
||||||
|
|
||||||
# Use the new step function
|
# Use the new step function
|
||||||
new_transition, terminate_episode = step_env_and_process_transition(
|
new_transition, terminate_episode = step_env_and_process_transition(
|
||||||
@@ -312,10 +309,11 @@ def act_with_policy(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Extract values from processed transition
|
# Extract values from processed transition
|
||||||
|
next_observation = new_transition[TransitionKey.OBSERVATION]
|
||||||
|
executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
|
||||||
reward = new_transition[TransitionKey.REWARD]
|
reward = new_transition[TransitionKey.REWARD]
|
||||||
done = new_transition.get(TransitionKey.DONE, False)
|
done = new_transition.get(TransitionKey.DONE, False)
|
||||||
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
|
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
|
||||||
processed_action = new_transition[TransitionKey.ACTION]
|
|
||||||
|
|
||||||
sum_reward_episode += float(reward)
|
sum_reward_episode += float(reward)
|
||||||
episode_total_steps += 1
|
episode_total_steps += 1
|
||||||
@@ -329,13 +327,13 @@ def act_with_policy(
|
|||||||
# Create transition for learner (convert to old format)
|
# Create transition for learner (convert to old format)
|
||||||
list_transition_to_send_to_learner.append(
|
list_transition_to_send_to_learner.append(
|
||||||
Transition(
|
Transition(
|
||||||
state=transition[TransitionKey.OBSERVATION],
|
state=observation,
|
||||||
action=processed_action,
|
action=executed_action,
|
||||||
reward=reward,
|
reward=reward,
|
||||||
next_state=new_transition[TransitionKey.OBSERVATION],
|
next_state=next_observation,
|
||||||
done=done,
|
done=done,
|
||||||
truncated=truncated,
|
truncated=truncated,
|
||||||
complementary_info=new_transition[TransitionKey.COMPLEMENTARY_DATA],
|
complementary_info={}, # new_transition[TransitionKey.COMPLEMENTARY_DATA],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -800,11 +800,15 @@ def make_processors(env, cfg):
|
|||||||
crop_params_dict=cfg.processor.crop_params_dict, resize_size=cfg.processor.resize_size
|
crop_params_dict=cfg.processor.crop_params_dict, resize_size=cfg.processor.resize_size
|
||||||
),
|
),
|
||||||
TimeLimitProcessor(max_episode_steps=int(cfg.processor.control_time_s * cfg.fps)),
|
TimeLimitProcessor(max_episode_steps=int(cfg.processor.control_time_s * cfg.fps)),
|
||||||
GripperPenaltyProcessor(
|
|
||||||
penalty=cfg.processor.gripper_penalty, max_gripper_pos=cfg.processor.max_gripper_pos
|
|
||||||
),
|
|
||||||
DeviceProcessor(device=cfg.device),
|
|
||||||
]
|
]
|
||||||
|
if cfg.processor.use_gripper:
|
||||||
|
env_pipeline_steps.append(
|
||||||
|
GripperPenaltyProcessor(
|
||||||
|
penalty=cfg.processor.gripper_penalty, max_gripper_pos=cfg.processor.max_gripper_pos
|
||||||
|
)
|
||||||
|
)
|
||||||
|
env_pipeline_steps.append(DeviceProcessor(device=cfg.device))
|
||||||
|
|
||||||
env_processor = RobotProcessor(steps=env_pipeline_steps)
|
env_processor = RobotProcessor(steps=env_pipeline_steps)
|
||||||
|
|
||||||
action_pipeline_steps = [
|
action_pipeline_steps = [
|
||||||
@@ -920,12 +924,13 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
|
|||||||
"action": action_features,
|
"action": action_features,
|
||||||
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
||||||
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||||
"complementary_info.discrete_penalty": {
|
}
|
||||||
|
if cfg.processor.use_gripper:
|
||||||
|
features["complementary_info.discrete_penalty"] = {
|
||||||
"dtype": "float32",
|
"dtype": "float32",
|
||||||
"shape": (1,),
|
"shape": (1,),
|
||||||
"names": ["discrete_penalty"],
|
"names": ["discrete_penalty"],
|
||||||
},
|
}
|
||||||
}
|
|
||||||
|
|
||||||
for key, value in transition[TransitionKey.OBSERVATION].items():
|
for key, value in transition[TransitionKey.OBSERVATION].items():
|
||||||
if key == "observation.state":
|
if key == "observation.state":
|
||||||
@@ -977,16 +982,17 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
|
|||||||
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
||||||
|
|
||||||
if cfg.mode == "record":
|
if cfg.mode == "record":
|
||||||
observations = {k: v.squeeze(0) for k, v in transition[TransitionKey.OBSERVATION].items()}
|
observations = {k: v.squeeze(0).cpu() for k, v in transition[TransitionKey.OBSERVATION].items()}
|
||||||
frame = {
|
frame = {
|
||||||
**observations,
|
**observations,
|
||||||
"action": transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"],
|
"action": transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"].cpu(),
|
||||||
"next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
"next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
||||||
"next.done": np.array([terminated or truncated], dtype=bool),
|
"next.done": np.array([terminated or truncated], dtype=bool),
|
||||||
"complementary_info.discrete_penalty": np.array(
|
|
||||||
[transition[TransitionKey.COMPLEMENTARY_DATA]["discrete_penalty"]], dtype=np.float32
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
if cfg.processor.use_gripper:
|
||||||
|
frame["complementary_info.discrete_penalty"] = np.array(
|
||||||
|
[transition[TransitionKey.COMPLEMENTARY_DATA]["discrete_penalty"]], dtype=np.float32
|
||||||
|
)
|
||||||
dataset.add_frame(frame, task=cfg.task)
|
dataset.add_frame(frame, task=cfg.task)
|
||||||
|
|
||||||
episode_step += 1
|
episode_step += 1
|
||||||
@@ -997,16 +1003,6 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
|
|||||||
logging.info(
|
logging.info(
|
||||||
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
|
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reset for new episode
|
|
||||||
obs, info = env.reset()
|
|
||||||
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
|
|
||||||
env_processor.reset()
|
|
||||||
action_processor.reset()
|
|
||||||
|
|
||||||
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
|
||||||
transition = env_processor(transition)
|
|
||||||
|
|
||||||
episode_step = 0
|
episode_step = 0
|
||||||
episode_idx += 1
|
episode_idx += 1
|
||||||
|
|
||||||
@@ -1019,6 +1015,15 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
|
|||||||
logging.info(f"Saving episode {episode_idx}")
|
logging.info(f"Saving episode {episode_idx}")
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
|
# Reset for new episode
|
||||||
|
obs, info = env.reset()
|
||||||
|
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
|
||||||
|
env_processor.reset()
|
||||||
|
action_processor.reset()
|
||||||
|
|
||||||
|
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
||||||
|
transition = env_processor(transition)
|
||||||
|
|
||||||
# Maintain fps timing
|
# Maintain fps timing
|
||||||
busy_wait(dt - (time.perf_counter() - step_start_time))
|
busy_wait(dt - (time.perf_counter() - step_start_time))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user