final refactor/fix

This commit is contained in:
Jade Choghari (jchoghar)
2025-08-25 06:25:02 -04:00
parent afad90ffaa
commit 8d2c66abd2
7 changed files with 47 additions and 75 deletions
+1 -1
View File
@@ -1,7 +1,7 @@
#!/bin/bash #!/bin/bash
# config # config
REPO_ID=yzembodied/libero_10_image_task_1 REPO_ID=jadechoghari/smol-libero
TASK=libero_10 TASK=libero_10
OUTPUT_DIR=./outputs/ OUTPUT_DIR=./outputs/
@@ -2,14 +2,12 @@
unset LEROBOT_HOME unset LEROBOT_HOME
unset HF_LEROBOT_HOME unset HF_LEROBOT_HOME
# === CONFIGURATION === # CONFIGURATION
POLICY_PATH="ganatrask/lerobot-pi0-libero-object" # or outputs/train/.../pretrained_model POLICY_PATH="ganatrask/lerobot-pi0-libero-object"
TASK=libero_object TASK=libero_object
ENV_TYPE="libero" ENV_TYPE="libero"
BATCH_SIZE=1 BATCH_SIZE=1
N_EPISODES=1 N_EPISODES=1
USE_AMP=false
DEVICE=cuda
# RUN EVALUATION # RUN EVALUATION
python src/lerobot/scripts/eval.py \ python src/lerobot/scripts/eval.py \
+2 -2
View File
@@ -295,8 +295,8 @@ class LiberoEnv(EnvConfig):
default_factory=lambda: { default_factory=lambda: {
"action": ACTION, "action": ACTION,
"agent_pos": OBS_STATE, "agent_pos": OBS_STATE,
"pixels/agentview_image": f"{OBS_IMAGE}", "pixels/agentview_image": f"{OBS_IMAGES}.image",
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGE_2}", "pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
} }
) )
+2 -2
View File
@@ -41,12 +41,12 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
Args: Args:
cfg (EnvConfig): the config of the environment to instantiate. cfg (EnvConfig): the config of the environment to instantiate.
n_envs (int, optional): The number of parallelized env to return. Defaults to 1. n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
use_async_envs (bool, optional): Wether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
False. False.
Raises: Raises:
ValueError: if n_envs < 1 ValueError: if n_envs < 1
ModuleNotFoundError: If the requested env package is not intalled ModuleNotFoundError: If the requested env package is not installed
Returns: Returns:
gym.vector.VectorEnv: The parallelized gym.env instance. gym.vector.VectorEnv: The parallelized gym.env instance.
+30 -36
View File
@@ -26,65 +26,59 @@ from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig from lerobot.envs.configs import EnvConfig
from lerobot.utils.utils import get_channel_first_image_shape from lerobot.utils.utils import get_channel_first_image_shape
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
def preprocess_observation( # TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
observations: dict[str, np.ndarray], cfg: dict[str, Any] = None
) -> dict[str, Tensor]:
"""Convert environment observation to LeRobot format observation. """Convert environment observation to LeRobot format observation.
Args: Args:
observations: Dictionary of observation batches from a Gym vector environment. observation: Dictionary of observation batches from a Gym vector environment.
cfg: Policy config containing expected feature keys.
Returns: Returns:
Dictionary of observation batches with keys renamed to match policy expectations. Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
""" """
# map to expected inputs for the policy
return_observations = {} return_observations = {}
# expected keys from policy
policy_img_keys = list(cfg.image_features.keys()) if cfg else ["observation.image"]
state_key = cfg.robot_state_feature_key if cfg else "observation.state"
# handle images
if "pixels" in observations: if "pixels" in observations:
if isinstance(observations["pixels"], dict): if isinstance(observations["pixels"], dict):
env_img_keys = list(observations["pixels"].keys()) imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
imgs = observations["pixels"]
else: else:
env_img_keys = ["pixels"] imgs = {"observation.image": observations["pixels"]}
imgs = {"pixels": observations["pixels"]}
# build rename map env_key -> policy_key
rename_map = dict(zip(env_img_keys, policy_img_keys, strict=False))
for imgkey, img in imgs.items(): for imgkey, img in imgs.items():
target_key = rename_map.get(imgkey, imgkey) # TODO(aliberts, rcadene): use transforms.ToTensor()?
img = torch.from_numpy(img) img = torch.from_numpy(img)
# sanity checks # When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
# This is the case for human-in-the-loop RL where there is only one environment.
if img.ndim == 3:
img = img.unsqueeze(0)
# sanity check that images are channel last
_, h, w, c = img.shape _, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, got {img.shape=}" assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
assert img.dtype == torch.uint8, f"expect torch.uint8, got {img.dtype=}"
# channel last → channel first, normalize # sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous() img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.float() / 255.0 img = img.type(torch.float32)
img /= 255
return_observations[target_key] = img return_observations[imgkey] = img
# handle state
if "environment_state" in observations: if "environment_state" in observations:
return_observations["observation.environment_state"] = torch.from_numpy( env_state = torch.from_numpy(observations["environment_state"]).float()
observations["environment_state"] if env_state.dim() == 1:
).float() env_state = env_state.unsqueeze(0)
return_observations[state_key] = torch.from_numpy(observations["agent_pos"]).float() return_observations["observation.environment_state"] = env_state
if "task" in observations: # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
return_observations["task"] = observations["task"] agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations["observation.state"] = agent_pos
return return_observations return return_observations
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
# (need to also refactor preprocess_observation and externalize normalization from policies) # (need to also refactor preprocess_observation and externalize normalization from policies)
+7 -24
View File
@@ -62,6 +62,7 @@ import einops
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
import torch import torch
from termcolor import colored
from torch import Tensor, nn from torch import Tensor, nn
from tqdm import trange from tqdm import trange
@@ -73,6 +74,7 @@ from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters from lerobot.policies.utils import get_device_from_parameters
from lerobot.utils.io_utils import write_video from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import ( from lerobot.utils.utils import (
get_safe_torch_device, get_safe_torch_device,
init_logging, init_logging,
@@ -146,8 +148,7 @@ def rollout(
check_env_attributes_and_types(env) check_env_attributes_and_types(env)
while not np.all(done) and step < max_steps: while not np.all(done) and step < max_steps:
# Numpy array to tensor and changing dictionary keys to LeRobot policy format. # Numpy array to tensor and changing dictionary keys to LeRobot policy format.
# observation = preprocess_observation(observation) observation = preprocess_observation(observation)
observation = preprocess_observation(observation, cfg=policy.config)
if return_observations: if return_observations:
all_observations.append(deepcopy(observation)) all_observations.append(deepcopy(observation))
@@ -459,24 +460,8 @@ def _compile_episode_data(
return data_dict return data_dict
def set_global_seed(seed):
"""Set seed for reproducibility."""
import random
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def log_output_dir(out_dir):
logging.info("Output dir:" + f" {out_dir}")
@parser.wrap() @parser.wrap()
def eval(cfg: EvalPipelineConfig): def eval_main(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg))) logging.info(pformat(asdict(cfg)))
# Check device is available # Check device is available
@@ -484,9 +469,9 @@ def eval(cfg: EvalPipelineConfig):
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(cfg.seed) set_seed(cfg.seed)
log_output_dir(cfg.output_dir) logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info("Making environment.") logging.info("Making environment.")
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
@@ -494,11 +479,9 @@ def eval(cfg: EvalPipelineConfig):
logging.info("Making policy.") logging.info("Making policy.")
policy = make_policy( policy = make_policy(
cfg=cfg.policy, cfg=cfg.policy,
# device=device,
env_cfg=cfg.env, env_cfg=cfg.env,
) )
policy.eval() policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
if cfg.env.multitask_eval: if cfg.env.multitask_eval:
info = eval_policy_multitask( info = eval_policy_multitask(
@@ -663,4 +646,4 @@ def eval_policy_multitask(
if __name__ == "__main__": if __name__ == "__main__":
init_logging() init_logging()
eval() eval_main()
+3 -6
View File
@@ -186,7 +186,6 @@ def train(cfg: TrainPipelineConfig):
dl_iter = cycle(dataloader) dl_iter = cycle(dataloader)
policy.train() policy.train()
train_metrics = { train_metrics = {
"loss": AverageMeter("loss", ":.3f"), "loss": AverageMeter("loss", ":.3f"),
"grad_norm": AverageMeter("grdn", ":.3f"), "grad_norm": AverageMeter("grdn", ":.3f"),
@@ -263,15 +262,14 @@ def train(cfg: TrainPipelineConfig):
max_parallel_tasks=cfg.env.max_parallel_tasks, max_parallel_tasks=cfg.env.max_parallel_tasks,
) )
aggregated = eval_info["overall"]["aggregated"] aggregated = eval_info["overall"]["aggregated"]
# Print per-suite stats # Print per-suite stats, log?
for task_group, task_group_info in eval_info.items(): for task_group, task_group_info in eval_info.items():
if task_group == "overall": if task_group == "overall":
continue # Skip the overall stats since we already printed it continue # Skip the overall stats since we already printed it
print(f"\nAggregated Metrics for {task_group}:") print(f"\nAggregated Metrics for {task_group}:")
print(task_group_info["aggregated"]) print(task_group_info["aggregated"])
breakpoint() breakpoint()
else: else:
print("START EVAL")
eval_info = eval_policy( eval_info = eval_policy(
eval_env, eval_env,
policy, policy,
@@ -280,9 +278,8 @@ def train(cfg: TrainPipelineConfig):
max_episodes_rendered=4, max_episodes_rendered=4,
start_seed=cfg.seed, start_seed=cfg.seed,
) )
breakpoint()
aggregated = eval_info["aggregated"] aggregated = eval_info["aggregated"]
print("END EVAL") breakpoint()
eval_metrics = { eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),