mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
Add end effector action space to hil-serl (#861)
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
committed by
AdilZouitine
parent
7960f2c3c1
commit
b82faf7d8c
@@ -315,16 +315,49 @@ def start_learner_server(
|
||||
|
||||
|
||||
def check_nan_in_transition(
    observations: dict[str, torch.Tensor],
    actions: torch.Tensor,
    next_state: dict[str, torch.Tensor],
    raise_error: bool = False,
) -> bool:
    """
    Check for NaN values in transition data.

    The observation tensors are checked first, then the next-state tensors,
    and the action tensor last. Every tensor containing a NaN is reported via
    ``logging.error``.

    Args:
        observations: Dictionary of observation tensors
        actions: Action tensor
        next_state: Dictionary of next state tensors
        raise_error: If True, raises ValueError when NaN is detected

    Returns:
        bool: True if NaN values were detected, False otherwise

    Raises:
        ValueError: If ``raise_error`` is True, on the first tensor found to
            contain NaN (after the error has been logged).
    """
    # Pair every tensor with the label used in the log / exception messages so
    # the three categories share a single check loop. Order matters: with
    # raise_error=True the first offending tensor in this order is reported.
    labelled_tensors = [
        *((f"observations[{key}]", tensor) for key, tensor in observations.items()),
        *((f"next_state[{key}]", tensor) for key, tensor in next_state.items()),
        ("actions", actions),
    ]

    nan_detected = False
    for label, tensor in labelled_tensors:
        if not torch.isnan(tensor).any():
            continue
        logging.error(f"{label} contains NaN values")
        nan_detected = True
        if raise_error:
            raise ValueError(f"NaN detected in {label}")

    return nan_detected
|
||||
|
||||
|
||||
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
||||
@@ -460,9 +493,18 @@ def add_actor_information_and_train(
|
||||
|
||||
for transition in transition_list:
|
||||
transition = move_transition_to_device(transition, device=device)
|
||||
if check_nan_in_transition(
|
||||
transition["state"], transition["action"], transition["next_state"]
|
||||
):
|
||||
logging.warning("NaN detected in transition, skipping")
|
||||
continue
|
||||
replay_buffer.add(**transition)
|
||||
if transition.get("complementary_info", {}).get("is_intervention"):
|
||||
|
||||
if cfg.dataset_repo_id is not None and transition.get(
|
||||
"complementary_info", {}
|
||||
).get("is_intervention"):
|
||||
offline_replay_buffer.add(**transition)
|
||||
|
||||
logging.debug("[LEARNER] Received transitions")
|
||||
logging.debug("[LEARNER] Waiting for interactions")
|
||||
while not interaction_message_queue.empty() and not shutdown_event.is_set():
|
||||
|
||||
Reference in New Issue
Block a user