mirror of https://github.com/huggingface/lerobot.git
Update ManiSkill configuration and replay buffer to support truncation and dataset handling
- Reduced the image size in the ManiSkill environment configuration from 128 to 64
- Added support for truncation in the replay buffer and actor server (illustrated in the sketch below)
- Updated the SAC policy configuration to use a specific dataset and modified the vision encoder settings
- Improved the dataset conversion process with progress tracking and task naming
- Added flexibility for joint action space masking in the learner server
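Handling truncation separately from termination matters for the learner: an episode cut off by a time limit is not a real terminal state, so the value target should still bootstrap from the next observation. A minimal sketch of that idea, using illustrative field names (`done`, `truncated`) rather than the exact ones in lerobot's `ReplayBuffer`:

import torch

def bootstrap_mask(done: torch.Tensor, truncated: torch.Tensor) -> torch.Tensor:
    # Only a true termination (done and not truncated) stops bootstrapping;
    # a time-limit truncation keeps the (1 - done) factor at 1.
    return 1.0 - (done & ~truncated).float()

# Episode 0 ended naturally, episode 1 was cut by the time limit, episode 2 continues.
done = torch.tensor([True, True, False])
truncated = torch.tensor([False, True, False])
print(bootstrap_mask(done, truncated))  # tensor([0., 1., 1.])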
@@ -153,7 +153,7 @@ def initialize_replay_buffer(
             capacity=cfg.training.online_buffer_capacity,
             device=device,
             state_keys=cfg.policy.input_shapes.keys(),
-            storage_device=device
+            storage_device=device,
         )

     dataset = LeRobotDataset(
@@ -169,8 +169,13 @@ def initialize_replay_buffer(
     )


-def get_observation_features(policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
-    if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
+def get_observation_features(
+    policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
+) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+    if (
+        policy.config.vision_encoder_name is None
+        or not policy.config.freeze_vision_encoder
+    ):
         return None, None

     with torch.no_grad():
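The `get_observation_features` reformatting above keeps the existing behaviour: image features are only precomputed when a vision encoder is configured and frozen, because a frozen encoder yields the same embeddings at every gradient step, so its forward pass can be moved out of the critic and actor losses. A rough sketch of the caching pattern, simplified and not the actual lerobot implementation:

import torch
from torch import nn

def maybe_precompute_features(
    encoder: nn.Module | None, images: torch.Tensor, frozen: bool
) -> torch.Tensor | None:
    # A trainable encoder must be re-run inside each loss so gradients flow
    # through it; in that case nothing is cached.
    if encoder is None or not frozen:
        return None
    # A frozen encoder never changes, so its output can be computed once per
    # batch without tracking gradients and reused by every loss term.
    with torch.no_grad():
        return encoder(images)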
@@ -338,6 +343,7 @@ def add_actor_information_and_train(
     interaction_step_shift = (
         resume_interaction_step if resume_interaction_step is not None else 0
     )
+    saved_data = False
     while True:
         if shutdown_event is not None and shutdown_event.is_set():
             logging.info("[LEARNER] Shutdown signal received. Exiting...")
@@ -372,7 +378,6 @@ def add_actor_information_and_train(
             batch_offline = offline_replay_buffer.sample(batch_size)
             batch = concatenate_batch_transitions(batch, batch_offline)

-
         actions = batch["action"]
         rewards = batch["reward"]
         observations = batch["state"]
@@ -382,7 +387,9 @@
             observations=observations, actions=actions, next_state=next_observations
         )

-        observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
+        observation_features, next_observation_features = get_observation_features(
+            policy, observations, next_observations
+        )
         with policy_lock:
             loss_critic = policy.compute_loss_critic(
                 observations=observations,
@@ -415,7 +422,9 @@
             observations=observations, actions=actions, next_state=next_observations
         )

-        observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
+        observation_features, next_observation_features = get_observation_features(
+            policy, observations, next_observations
+        )
         with policy_lock:
             loss_critic = policy.compute_loss_critic(
                 observations=observations,
@@ -436,8 +445,10 @@
         if optimization_step % cfg.training.policy_update_freq == 0:
             for _ in range(cfg.training.policy_update_freq):
                 with policy_lock:
-                    loss_actor = policy.compute_loss_actor(observations=observations,
-                                                           observation_features=observation_features)
+                    loss_actor = policy.compute_loss_actor(
+                        observations=observations,
+                        observation_features=observation_features,
+                    )

                 optimizers["actor"].zero_grad()
                 loss_actor.backward()
@@ -447,7 +458,7 @@

                 loss_temperature = policy.compute_loss_temperature(
                     observations=observations,
-                    observation_features=observation_features
+                    observation_features=observation_features,
                 )
                 optimizers["temperature"].zero_grad()
                 loss_temperature.backward()
@@ -458,7 +469,9 @@
             policy.update_target_networks()
         if optimization_step % cfg.training.log_freq == 0:
             training_infos["Optimization step"] = optimization_step
-            logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
+            logger.log_dict(
+                d=training_infos, mode="train", custom_step_key="Optimization step"
+            )
             # logging.info(f"Training infos: {training_infos}")

         time_for_one_optimization_step = time.time() - time_for_one_optimization_step
@@ -621,11 +634,13 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         logging.info("make_dataset offline buffer")
         offline_dataset = make_dataset(cfg)
         logging.info("Convertion to a offline replay buffer")
-        active_action_dims = [
-            i
-            for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
-            if mask
-        ]
+        active_action_dims = None
+        if cfg.env.wrapper.joint_masking_action_space is not None:
+            active_action_dims = [
+                i
+                for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
+                if mask
+            ]
         offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
             offline_dataset,
             device=device,
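The new `active_action_dims` guard makes joint masking optional: when `cfg.env.wrapper.joint_masking_action_space` is unset, dataset actions are used as-is; when a boolean mask is provided, only the enabled joints are kept. A small illustration of how such an index list selects action dimensions (the shapes here are made up for the example):

import torch

# Hypothetical mask enabling joints 0, 1 and 3 of a 4-dim action space.
joint_masking_action_space = [True, True, False, True]
active_action_dims = [i for i, m in enumerate(joint_masking_action_space) if m]

actions = torch.randn(32, 4)             # a batch of full actions
masked = actions[:, active_action_dims]  # keep only the active joints
print(masked.shape)                      # torch.Size([32, 3])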