mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 06:29:47 +00:00
chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -17,19 +17,21 @@ from pathlib import Path
|
||||
|
||||
from huggingface_hub.constants import HF_HOME
|
||||
|
||||
OBS_ENV_STATE = "observation.environment_state"
|
||||
OBS_STATE = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
OBS_LANGUAGE = "observation.language"
|
||||
OBS_STR = "observation"
|
||||
OBS_PREFIX = OBS_STR + "."
|
||||
OBS_ENV_STATE = OBS_STR + ".environment_state"
|
||||
OBS_STATE = OBS_STR + ".state"
|
||||
OBS_IMAGE = OBS_STR + ".image"
|
||||
OBS_IMAGES = OBS_IMAGE + "s"
|
||||
OBS_LANGUAGE = OBS_STR + ".language"
|
||||
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
||||
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||
|
||||
ACTION = "action"
|
||||
REWARD = "next.reward"
|
||||
TRUNCATED = "next.truncated"
|
||||
DONE = "next.done"
|
||||
|
||||
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
||||
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||
|
||||
ROBOTS = "robots"
|
||||
ROBOT_TYPE = "robot_type"
|
||||
TELEOPERATORS = "teleoperators"
|
||||
|
||||
@@ -19,6 +19,8 @@ from typing import Any
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
|
||||
from .constants import OBS_PREFIX, OBS_STR
|
||||
|
||||
|
||||
def init_rerun(session_name: str = "lerobot_control_loop") -> None:
|
||||
"""Initializes the Rerun SDK for visualizing the control loop."""
|
||||
@@ -63,7 +65,7 @@ def log_rerun_data(
|
||||
for k, v in observation.items():
|
||||
if v is None:
|
||||
continue
|
||||
key = k if str(k).startswith("observation.") else f"observation.{k}"
|
||||
key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"
|
||||
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalar(float(v)))
|
||||
|
||||
Reference in New Issue
Block a user