chore: replace hard-coded action values with constants throughout all the source code (#2055)

* chore: replace hard-coded 'action' values with constants throughout all the source code

* chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-26 13:33:18 +02:00
committed by GitHub
parent 9627765ce2
commit d2782cf66b
47 changed files with 269 additions and 255 deletions
+9 -9
View File
@@ -24,7 +24,7 @@ import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import OBS_IMAGE
from lerobot.utils.constants import ACTION, OBS_IMAGE
from lerobot.utils.transition import Transition
@@ -467,7 +467,7 @@ class ReplayBuffer:
if list_transition:
first_transition = list_transition[0]
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
first_action = first_transition["action"].to(device)
first_action = first_transition[ACTION].to(device)
# Get complementary info if available
first_complementary_info = None
@@ -492,7 +492,7 @@ class ReplayBuffer:
elif isinstance(v, torch.Tensor):
data[k] = v.to(storage_device)
action = data["action"]
action = data[ACTION]
replay_buffer.add(
state=data["state"],
@@ -530,8 +530,8 @@ class ReplayBuffer:
# Add "action"
sample_action = self.actions[0]
act_info = guess_feature_info(t=sample_action, name="action")
features["action"] = act_info
act_info = guess_feature_info(t=sample_action, name=ACTION)
features[ACTION] = act_info
# Add "reward" and "done"
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
@@ -577,7 +577,7 @@ class ReplayBuffer:
frame_dict[key] = self.states[key][actual_idx].cpu()
# Fill action, reward, done
frame_dict["action"] = self.actions[actual_idx].cpu()
frame_dict[ACTION] = self.actions[actual_idx].cpu()
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
frame_dict["task"] = task_name
@@ -668,7 +668,7 @@ class ReplayBuffer:
current_state[key] = val.unsqueeze(0) # Add batch dimension
# ----- 2) Action -----
action = current_sample["action"].unsqueeze(0) # Add batch dimension
action = current_sample[ACTION].unsqueeze(0) # Add batch dimension
# ----- 3) Reward and done -----
reward = float(current_sample["next.reward"].item()) # ensure float
@@ -788,8 +788,8 @@ def concatenate_batch_transitions(
}
# Concatenate basic fields
left_batch_transitions["action"] = torch.cat(
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
left_batch_transitions[ACTION] = torch.cat(
[left_batch_transitions[ACTION], right_batch_transition[ACTION]], dim=0
)
left_batch_transitions["reward"] = torch.cat(
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
+5 -5
View File
@@ -73,7 +73,7 @@ from lerobot.teleoperators import (
)
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
@@ -601,7 +601,7 @@ def control_loop(
if cfg.mode == "record":
action_features = teleop_device.action_features
features = {
"action": action_features,
ACTION: action_features,
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
}
@@ -672,7 +672,7 @@ def control_loop(
)
frame = {
**observations,
"action": action_to_record.cpu(),
ACTION: action_to_record.cpu(),
"next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
"next.done": np.array([terminated or truncated], dtype=bool),
}
@@ -733,7 +733,7 @@ def replay_trajectory(
download_videos=False,
)
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode)
actions = episode_frames.select_columns("action")
actions = episode_frames.select_columns(ACTION)
_, info = env.reset()
@@ -741,7 +741,7 @@ def replay_trajectory(
start_time = time.perf_counter()
transition = create_transition(
observation=env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {},
action=action_data["action"],
action=action_data[ACTION],
)
transition = action_processor(transition)
env.step(transition[TransitionKey.ACTION])
+6 -5
View File
@@ -80,6 +80,7 @@ from lerobot.transport.utils import (
state_to_bytes,
)
from lerobot.utils.constants import (
ACTION,
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
@@ -402,7 +403,7 @@ def add_actor_information_and_train(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch["action"]
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
@@ -415,7 +416,7 @@ def add_actor_information_and_train(
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
"action": actions,
ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
@@ -460,7 +461,7 @@ def add_actor_information_and_train(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch["action"]
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
@@ -474,7 +475,7 @@ def add_actor_information_and_train(
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
"action": actions,
ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
@@ -1155,7 +1156,7 @@ def process_transitions(
# Skip transitions with NaN values
if check_nan_in_transition(
observations=transition["state"],
actions=transition["action"],
actions=transition[ACTION],
next_state=transition["next_state"],
):
logging.warning("[LEARNER] NaN detected in transition, skipping")