RL works at this commit - fixed actor.py and bugs in gym_manipulator

This commit is contained in:
Michel Aractingi
2025-08-03 23:21:13 +02:00
parent ff38a51df9
commit f49280e89b
2 changed files with 41 additions and 38 deletions
+14 -16
View File
@@ -243,7 +243,7 @@ def act_with_policy(
logging.info("make_env online")
online_env, teleop_device = make_robot_env(cfg=cfg.env)
env_processor, action_processor = make_processors(online_env, cfg.env, cfg.policy.device)
env_processor, action_processor = make_processors(online_env, cfg.env)
set_seed(cfg.seed)
device = get_safe_torch_device(cfg.policy.device, log=True)
@@ -288,18 +288,15 @@ def act_with_policy(
logging.info("[ACTOR] Shutting down act_with_policy")
return
if interaction_step >= cfg.policy.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement
with policy_timer:
# Extract observation from transition for policy
batch_obs = transition[TransitionKey.OBSERVATION]
action = policy.select_action(batch=batch_obs)
policy_fps = policy_timer.fps_last
observation = transition[TransitionKey.OBSERVATION]
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
# Time policy inference and check if it meets FPS requirement
with policy_timer:
# Extract observation from transition for policy
action = policy.select_action(batch=observation)
policy_fps = policy_timer.fps_last
else:
action = online_env.action_space.sample()
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
# Use the new step function
new_transition, terminate_episode = step_env_and_process_transition(
@@ -312,10 +309,11 @@ def act_with_policy(
)
# Extract values from processed transition
next_observation = new_transition[TransitionKey.OBSERVATION]
executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
reward = new_transition[TransitionKey.REWARD]
done = new_transition.get(TransitionKey.DONE, False)
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
processed_action = new_transition[TransitionKey.ACTION]
sum_reward_episode += float(reward)
episode_total_steps += 1
@@ -329,13 +327,13 @@ def act_with_policy(
# Create transition for learner (convert to old format)
list_transition_to_send_to_learner.append(
Transition(
state=transition[TransitionKey.OBSERVATION],
action=processed_action,
state=observation,
action=executed_action,
reward=reward,
next_state=new_transition[TransitionKey.OBSERVATION],
next_state=next_observation,
done=done,
truncated=truncated,
complementary_info=new_transition[TransitionKey.COMPLEMENTARY_DATA],
complementary_info={}, # new_transition[TransitionKey.COMPLEMENTARY_DATA],
)
)
+27 -22
View File
@@ -800,11 +800,15 @@ def make_processors(env, cfg):
crop_params_dict=cfg.processor.crop_params_dict, resize_size=cfg.processor.resize_size
),
TimeLimitProcessor(max_episode_steps=int(cfg.processor.control_time_s * cfg.fps)),
GripperPenaltyProcessor(
penalty=cfg.processor.gripper_penalty, max_gripper_pos=cfg.processor.max_gripper_pos
),
DeviceProcessor(device=cfg.device),
]
if cfg.processor.use_gripper:
env_pipeline_steps.append(
GripperPenaltyProcessor(
penalty=cfg.processor.gripper_penalty, max_gripper_pos=cfg.processor.max_gripper_pos
)
)
env_pipeline_steps.append(DeviceProcessor(device=cfg.device))
env_processor = RobotProcessor(steps=env_pipeline_steps)
action_pipeline_steps = [
@@ -920,12 +924,13 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
"action": action_features,
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
"complementary_info.discrete_penalty": {
}
if cfg.processor.use_gripper:
features["complementary_info.discrete_penalty"] = {
"dtype": "float32",
"shape": (1,),
"names": ["discrete_penalty"],
},
}
}
for key, value in transition[TransitionKey.OBSERVATION].items():
if key == "observation.state":
@@ -977,16 +982,17 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
truncated = transition.get(TransitionKey.TRUNCATED, False)
if cfg.mode == "record":
observations = {k: v.squeeze(0) for k, v in transition[TransitionKey.OBSERVATION].items()}
observations = {k: v.squeeze(0).cpu() for k, v in transition[TransitionKey.OBSERVATION].items()}
frame = {
**observations,
"action": transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"],
"action": transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"].cpu(),
"next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
"next.done": np.array([terminated or truncated], dtype=bool),
"complementary_info.discrete_penalty": np.array(
[transition[TransitionKey.COMPLEMENTARY_DATA]["discrete_penalty"]], dtype=np.float32
),
}
if cfg.processor.use_gripper:
frame["complementary_info.discrete_penalty"] = np.array(
[transition[TransitionKey.COMPLEMENTARY_DATA]["discrete_penalty"]], dtype=np.float32
)
dataset.add_frame(frame, task=cfg.task)
episode_step += 1
@@ -997,16 +1003,6 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
logging.info(
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
)
# Reset for new episode
obs, info = env.reset()
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
env_processor.reset()
action_processor.reset()
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
transition = env_processor(transition)
episode_step = 0
episode_idx += 1
@@ -1019,6 +1015,15 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
logging.info(f"Saving episode {episode_idx}")
dataset.save_episode()
# Reset for new episode
obs, info = env.reset()
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
env_processor.reset()
action_processor.reset()
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
transition = env_processor(transition)
# Maintain fps timing
busy_wait(dt - (time.perf_counter() - step_start_time))