Hardcoded some normalization parameters. TODO refactor

Added action masking at the level of the intervention actions and the offline dataset

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Author: Michel Aractingi
Date: 2025-02-13 14:27:14 +01:00
Committed by: AdilZouitine
Parent: a0e0a9a9b1
Commit: eb7e28d9d9
6 changed files with 36 additions and 8 deletions
@@ -354,7 +354,7 @@ def add_actor_information_and_train(
 transition = move_transition_to_device(transition, device=device)
 replay_buffer.add(**transition)
-if transition.get("complementary_info", {}).get("is_interaction"):
+if transition.get("complementary_info", {}).get("is_intervention"):
     offline_replay_buffer.add(**transition)
 while not interaction_message_queue.empty():
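
For context, a minimal runnable sketch of the routing this hunk fixes; the buffer class and the example transition below are hypothetical stand-ins for the real replay buffers:

from typing import Any

class _Buffer:
    """Hypothetical stand-in for the repo's replay buffers."""
    def __init__(self) -> None:
        self.items: list[dict[str, Any]] = []

    def add(self, **transition: Any) -> None:
        self.items.append(transition)

replay_buffer = _Buffer()
offline_replay_buffer = _Buffer()

def route(transition: dict[str, Any]) -> None:
    # Every transition lands in the online buffer.
    replay_buffer.add(**transition)
    # Only human interventions are mirrored into the offline buffer; the old
    # key "is_interaction" never matched, so this branch previously never fired.
    if transition.get("complementary_info", {}).get("is_intervention"):
        offline_replay_buffer.add(**transition)

route({"obs": 0, "complementary_info": {"is_intervention": True}})
assert len(offline_replay_buffer.items) == 1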
@@ -568,6 +568,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 ### To avoid sending a SACPolicy object through the port, we create a policy instance
 ### on both sides; the learner sends the updated parameters every n steps to update the actor's parameters
 # TODO: At some point we should just need to make the SAC policy
+policy_lock = Lock()
 policy: SACPolicy = make_policy(
     hydra_cfg=cfg,
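
A minimal sketch of what the new policy_lock guards, assuming the learner/actor split described in the comments above: one thread writes updated parameters while another snapshots them to send to the actor. The thread bodies and the parameter dict are illustrative, not the repo's actual objects:

import threading

policy_params = {"step": 0}          # stand-in for the shared SACPolicy state
policy_lock = threading.Lock()

def learner_update(step: int) -> None:
    with policy_lock:                # writers are serialized
        policy_params["step"] = step

def snapshot_for_actor() -> dict:
    with policy_lock:                # readers never see a half-written update
        return dict(policy_params)

t = threading.Thread(target=learner_update, args=(1,))
t.start()
t.join()
print(snapshot_for_actor())         # {'step': 1}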
@@ -593,8 +594,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
 logging.info("make_dataset offline buffer")
 offline_dataset = make_dataset(cfg)
 logging.info("Conversion to an offline replay buffer")
+active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask]
 offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
-    offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
+    offline_dataset,
+    device=device,
+    state_keys=cfg.policy.input_shapes.keys(),
+    action_mask=active_action_dims,
 )
 batch_size: int = batch_size // 2  # We will sample from both replay buffers
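
As a rough illustration of the masking step (the mask values are invented, and applying action_mask as a column selection is an assumption about ReplayBuffer.from_lerobot_dataset, not taken from its source):

import torch

# Boolean per-joint mask, as in cfg.env.wrapper.joint_masking_action_space.
joint_masking_action_space = [True, True, False, True]
active_action_dims = [i for i, mask in enumerate(joint_masking_action_space) if mask]

actions = torch.randn(8, len(joint_masking_action_space))  # (batch, action_dim)
masked_actions = actions[:, active_action_dims]             # keep active joints only
print(masked_actions.shape)                                 # torch.Size([8, 3])

The batch size is then halved because, per the comment on the last line of the hunk, each gradient step samples half its transitions from the online buffer and half from this offline buffer.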