mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
chore: replace hard-coded action values with constants throughout all the source code (#2055)
* chore: replace hard-coded 'action' values with constants throughout all the source code * chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
@@ -24,7 +24,7 @@ import torch.nn.functional as F # noqa: N812
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE
|
||||
from lerobot.utils.transition import Transition
|
||||
|
||||
|
||||
@@ -467,7 +467,7 @@ class ReplayBuffer:
|
||||
if list_transition:
|
||||
first_transition = list_transition[0]
|
||||
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
|
||||
first_action = first_transition["action"].to(device)
|
||||
first_action = first_transition[ACTION].to(device)
|
||||
|
||||
# Get complementary info if available
|
||||
first_complementary_info = None
|
||||
@@ -492,7 +492,7 @@ class ReplayBuffer:
|
||||
elif isinstance(v, torch.Tensor):
|
||||
data[k] = v.to(storage_device)
|
||||
|
||||
action = data["action"]
|
||||
action = data[ACTION]
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
@@ -530,8 +530,8 @@ class ReplayBuffer:
|
||||
|
||||
# Add "action"
|
||||
sample_action = self.actions[0]
|
||||
act_info = guess_feature_info(t=sample_action, name="action")
|
||||
features["action"] = act_info
|
||||
act_info = guess_feature_info(t=sample_action, name=ACTION)
|
||||
features[ACTION] = act_info
|
||||
|
||||
# Add "reward" and "done"
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
|
||||
@@ -577,7 +577,7 @@ class ReplayBuffer:
|
||||
frame_dict[key] = self.states[key][actual_idx].cpu()
|
||||
|
||||
# Fill action, reward, done
|
||||
frame_dict["action"] = self.actions[actual_idx].cpu()
|
||||
frame_dict[ACTION] = self.actions[actual_idx].cpu()
|
||||
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||
frame_dict["task"] = task_name
|
||||
@@ -668,7 +668,7 @@ class ReplayBuffer:
|
||||
current_state[key] = val.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 2) Action -----
|
||||
action = current_sample["action"].unsqueeze(0) # Add batch dimension
|
||||
action = current_sample[ACTION].unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 3) Reward and done -----
|
||||
reward = float(current_sample["next.reward"].item()) # ensure float
|
||||
@@ -788,8 +788,8 @@ def concatenate_batch_transitions(
|
||||
}
|
||||
|
||||
# Concatenate basic fields
|
||||
left_batch_transitions["action"] = torch.cat(
|
||||
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
|
||||
left_batch_transitions[ACTION] = torch.cat(
|
||||
[left_batch_transitions[ACTION], right_batch_transition[ACTION]], dim=0
|
||||
)
|
||||
left_batch_transitions["reward"] = torch.cat(
|
||||
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
|
||||
|
||||
@@ -73,7 +73,7 @@ from lerobot.teleoperators import (
|
||||
)
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -601,7 +601,7 @@ def control_loop(
|
||||
if cfg.mode == "record":
|
||||
action_features = teleop_device.action_features
|
||||
features = {
|
||||
"action": action_features,
|
||||
ACTION: action_features,
|
||||
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||
}
|
||||
@@ -672,7 +672,7 @@ def control_loop(
|
||||
)
|
||||
frame = {
|
||||
**observations,
|
||||
"action": action_to_record.cpu(),
|
||||
ACTION: action_to_record.cpu(),
|
||||
"next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
||||
"next.done": np.array([terminated or truncated], dtype=bool),
|
||||
}
|
||||
@@ -733,7 +733,7 @@ def replay_trajectory(
|
||||
download_videos=False,
|
||||
)
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode)
|
||||
actions = episode_frames.select_columns("action")
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
_, info = env.reset()
|
||||
|
||||
@@ -741,7 +741,7 @@ def replay_trajectory(
|
||||
start_time = time.perf_counter()
|
||||
transition = create_transition(
|
||||
observation=env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {},
|
||||
action=action_data["action"],
|
||||
action=action_data[ACTION],
|
||||
)
|
||||
transition = action_processor(transition)
|
||||
env.step(transition[TransitionKey.ACTION])
|
||||
|
||||
@@ -80,6 +80,7 @@ from lerobot.transport.utils import (
|
||||
state_to_bytes,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
PRETRAINED_MODEL_DIR,
|
||||
@@ -402,7 +403,7 @@ def add_actor_information_and_train(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch["action"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
@@ -415,7 +416,7 @@ def add_actor_information_and_train(
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
"action": actions,
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
@@ -460,7 +461,7 @@ def add_actor_information_and_train(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch["action"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
@@ -474,7 +475,7 @@ def add_actor_information_and_train(
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
"action": actions,
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
@@ -1155,7 +1156,7 @@ def process_transitions(
|
||||
# Skip transitions with NaN values
|
||||
if check_nan_in_transition(
|
||||
observations=transition["state"],
|
||||
actions=transition["action"],
|
||||
actions=transition[ACTION],
|
||||
next_state=transition["next_state"],
|
||||
):
|
||||
logging.warning("[LEARNER] NaN detected in transition, skipping")
|
||||
|
||||
Reference in New Issue
Block a user