chore: replace hard-coded obs values with constants throughout all the source code (#2037)

* chore: replace hard-coded OBS values with constants throughout all the source code

* chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-25 15:36:47 +02:00
committed by GitHub
parent ddba994d73
commit 43d878a102
52 changed files with 659 additions and 649 deletions
+2 -1
View File
@@ -24,6 +24,7 @@ import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import OBS_IMAGE
from lerobot.utils.transition import Transition
@@ -240,7 +241,7 @@ class ReplayBuffer:
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
# Identify image keys that need augmentation
image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
# Create batched state and next_state
batch_state = {}
+4 -3
View File
@@ -73,6 +73,7 @@ from lerobot.teleoperators import (
)
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
@@ -180,7 +181,7 @@ class RobotEnv(gym.Env):
# Define observation spaces for images and other states.
if current_observation is not None and "pixels" in current_observation:
prefix = "observation.images"
prefix = OBS_IMAGES
observation_spaces = {
f"{prefix}.{key}": gym.spaces.Box(
low=0, high=255, shape=current_observation["pixels"][key].shape, dtype=np.uint8
@@ -190,7 +191,7 @@ class RobotEnv(gym.Env):
if current_observation is not None:
agent_pos = current_observation["agent_pos"]
observation_spaces["observation.state"] = gym.spaces.Box(
observation_spaces[OBS_STATE] = gym.spaces.Box(
low=0,
high=10,
shape=agent_pos.shape,
@@ -612,7 +613,7 @@ def control_loop(
}
for key, value in transition[TransitionKey.OBSERVATION].items():
if key == "observation.state":
if key == OBS_STATE:
features[key] = {
"dtype": "float32",
"shape": value.squeeze(0).shape,