mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
chore: replace hard-coded next values with constants throughout all the source code (#2056)
This commit is contained in:
@@ -24,7 +24,7 @@ import torch.nn.functional as F # noqa: N812
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, REWARD
|
||||
from lerobot.utils.transition import Transition
|
||||
|
||||
|
||||
@@ -534,8 +534,8 @@ class ReplayBuffer:
|
||||
features[ACTION] = act_info
|
||||
|
||||
# Add "reward" and "done"
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
|
||||
features["next.done"] = {"dtype": "bool", "shape": (1,)}
|
||||
features[REWARD] = {"dtype": "float32", "shape": (1,)}
|
||||
features[DONE] = {"dtype": "bool", "shape": (1,)}
|
||||
|
||||
# Add state keys
|
||||
for key in self.states:
|
||||
@@ -578,8 +578,8 @@ class ReplayBuffer:
|
||||
|
||||
# Fill action, reward, done
|
||||
frame_dict[ACTION] = self.actions[actual_idx].cpu()
|
||||
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||
frame_dict[REWARD] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||
frame_dict[DONE] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||
frame_dict["task"] = task_name
|
||||
|
||||
# Add complementary_info if available
|
||||
@@ -648,7 +648,7 @@ class ReplayBuffer:
|
||||
|
||||
# Check if the dataset has "next.done" key
|
||||
sample = dataset[0]
|
||||
has_done_key = "next.done" in sample
|
||||
has_done_key = DONE in sample
|
||||
|
||||
# Check for complementary_info keys
|
||||
complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
|
||||
@@ -671,11 +671,11 @@ class ReplayBuffer:
|
||||
action = current_sample[ACTION].unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 3) Reward and done -----
|
||||
reward = float(current_sample["next.reward"].item()) # ensure float
|
||||
reward = float(current_sample[REWARD].item()) # ensure float
|
||||
|
||||
# Determine done flag - use next.done if available, otherwise infer from episode boundaries
|
||||
if has_done_key:
|
||||
done = bool(current_sample["next.done"].item()) # ensure bool
|
||||
done = bool(current_sample[DONE].item()) # ensure bool
|
||||
else:
|
||||
# If this is the last frame or if next frame is in a different episode, mark as done
|
||||
done = False
|
||||
|
||||
@@ -25,6 +25,7 @@ import torchvision.transforms.functional as F # type: ignore # noqa: N812
|
||||
from tqdm import tqdm # type: ignore
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import DONE, REWARD
|
||||
|
||||
|
||||
def select_rect_roi(img):
|
||||
@@ -212,7 +213,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
for key, value in frame.items():
|
||||
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"):
|
||||
continue
|
||||
if key in ("next.done", "next.reward"):
|
||||
if key in (DONE, REWARD):
|
||||
# if not isinstance(value, str) and len(value.shape) == 0:
|
||||
value = value.unsqueeze(0)
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ from lerobot.teleoperators import (
|
||||
)
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -602,8 +602,8 @@ def control_loop(
|
||||
action_features = teleop_device.action_features
|
||||
features = {
|
||||
ACTION: action_features,
|
||||
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||
REWARD: {"dtype": "float32", "shape": (1,), "names": None},
|
||||
DONE: {"dtype": "bool", "shape": (1,), "names": None},
|
||||
}
|
||||
if use_gripper:
|
||||
features["complementary_info.discrete_penalty"] = {
|
||||
@@ -673,8 +673,8 @@ def control_loop(
|
||||
frame = {
|
||||
**observations,
|
||||
ACTION: action_to_record.cpu(),
|
||||
"next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
||||
"next.done": np.array([terminated or truncated], dtype=bool),
|
||||
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
||||
DONE: np.array([terminated or truncated], dtype=bool),
|
||||
}
|
||||
if use_gripper:
|
||||
discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)
|
||||
|
||||
Reference in New Issue
Block a user