Update ManiSkill configuration and replay buffer to support truncation and dataset handling

- Reduced image size in ManiSkill environment configuration from 128 to 64
- Added support for truncation in replay buffer and actor server
- Updated SAC policy configuration to use a specific dataset and modify vision encoder settings
- Improved dataset conversion process with progress tracking and task naming
- Added flexibility for joint action space masking in learner server
This commit is contained in:
AdilZouitine
2025-02-24 16:53:37 +00:00
parent ff223c106d
commit 2c799508d7
5 changed files with 78 additions and 27 deletions
+30 -15
View File
@@ -153,7 +153,7 @@ def initialize_replay_buffer(
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
storage_device=device
storage_device=device,
)
dataset = LeRobotDataset(
@@ -169,8 +169,13 @@ def initialize_replay_buffer(
)
def get_observation_features(policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
def get_observation_features(
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
if (
policy.config.vision_encoder_name is None
or not policy.config.freeze_vision_encoder
):
return None, None
with torch.no_grad():
@@ -338,6 +343,7 @@ def add_actor_information_and_train(
interaction_step_shift = (
resume_interaction_step if resume_interaction_step is not None else 0
)
saved_data = False
while True:
if shutdown_event is not None and shutdown_event.is_set():
logging.info("[LEARNER] Shutdown signal received. Exiting...")
@@ -372,7 +378,6 @@ def add_actor_information_and_train(
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(batch, batch_offline)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
@@ -382,7 +387,9 @@ def add_actor_information_and_train(
observations=observations, actions=actions, next_state=next_observations
)
observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -415,7 +422,9 @@ def add_actor_information_and_train(
observations=observations, actions=actions, next_state=next_observations
)
observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -436,8 +445,10 @@ def add_actor_information_and_train(
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
with policy_lock:
loss_actor = policy.compute_loss_actor(observations=observations,
observation_features=observation_features)
loss_actor = policy.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
optimizers["actor"].zero_grad()
loss_actor.backward()
@@ -447,7 +458,7 @@ def add_actor_information_and_train(
loss_temperature = policy.compute_loss_temperature(
observations=observations,
observation_features=observation_features
observation_features=observation_features,
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
@@ -458,7 +469,9 @@ def add_actor_information_and_train(
policy.update_target_networks()
if optimization_step % cfg.training.log_freq == 0:
training_infos["Optimization step"] = optimization_step
logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
logger.log_dict(
d=training_infos, mode="train", custom_step_key="Optimization step"
)
# logging.info(f"Training infos: {training_infos}")
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
@@ -621,11 +634,13 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer")
active_action_dims = [
i
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
]
active_action_dims = None
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
]
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
device=device,