mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -24,6 +24,7 @@ import torch.nn.functional as F # noqa: N812
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
from lerobot.utils.transition import Transition
|
||||
|
||||
|
||||
@@ -240,7 +241,7 @@ class ReplayBuffer:
|
||||
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
|
||||
|
||||
# Identify image keys that need augmentation
|
||||
image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
|
||||
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
|
||||
|
||||
# Create batched state and next_state
|
||||
batch_state = {}
|
||||
|
||||
@@ -73,6 +73,7 @@ from lerobot.teleoperators import (
|
||||
)
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -180,7 +181,7 @@ class RobotEnv(gym.Env):
|
||||
|
||||
# Define observation spaces for images and other states.
|
||||
if current_observation is not None and "pixels" in current_observation:
|
||||
prefix = "observation.images"
|
||||
prefix = OBS_IMAGES
|
||||
observation_spaces = {
|
||||
f"{prefix}.{key}": gym.spaces.Box(
|
||||
low=0, high=255, shape=current_observation["pixels"][key].shape, dtype=np.uint8
|
||||
@@ -190,7 +191,7 @@ class RobotEnv(gym.Env):
|
||||
|
||||
if current_observation is not None:
|
||||
agent_pos = current_observation["agent_pos"]
|
||||
observation_spaces["observation.state"] = gym.spaces.Box(
|
||||
observation_spaces[OBS_STATE] = gym.spaces.Box(
|
||||
low=0,
|
||||
high=10,
|
||||
shape=agent_pos.shape,
|
||||
@@ -612,7 +613,7 @@ def control_loop(
|
||||
}
|
||||
|
||||
for key, value in transition[TransitionKey.OBSERVATION].items():
|
||||
if key == "observation.state":
|
||||
if key == OBS_STATE:
|
||||
features[key] = {
|
||||
"dtype": "float32",
|
||||
"shape": value.squeeze(0).shape,
|
||||
|
||||
Reference in New Issue
Block a user