mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
Hardcoded some normalization parameters. TODO refactor
Added action masking at the level of the intervention actions and the offline dataset.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
committed by
AdilZouitine
parent
a0e0a9a9b1
commit
eb7e28d9d9
@@ -354,7 +354,7 @@ def add_actor_information_and_train(
|
||||
transition = move_transition_to_device(transition, device=device)
|
||||
replay_buffer.add(**transition)
|
||||
|
||||
if transition.get("complementary_info", {}).get("is_interaction"):
|
||||
if transition.get("complementary_info", {}).get("is_intervention"):
|
||||
offline_replay_buffer.add(**transition)
|
||||
|
||||
while not interaction_message_queue.empty():
|
||||
@@ -568,6 +568,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
# TODO: At some point we should just need make sac policy
|
||||
|
||||
policy_lock = Lock()
|
||||
policy: SACPolicy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
@@ -593,8 +594,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
logging.info("Convertion to a offline replay buffer")
|
||||
active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask]
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
|
||||
offline_dataset,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
action_mask=active_action_dims,
|
||||
)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
|
||||
Reference in New Issue
Block a user