mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
final refactor/fix
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# config
|
# config
|
||||||
REPO_ID=yzembodied/libero_10_image_task_1
|
REPO_ID=jadechoghari/smol-libero
|
||||||
TASK=libero_10
|
TASK=libero_10
|
||||||
OUTPUT_DIR=./outputs/
|
OUTPUT_DIR=./outputs/
|
||||||
|
|
||||||
|
|||||||
@@ -2,14 +2,12 @@
|
|||||||
|
|
||||||
unset LEROBOT_HOME
|
unset LEROBOT_HOME
|
||||||
unset HF_LEROBOT_HOME
|
unset HF_LEROBOT_HOME
|
||||||
# === CONFIGURATION ===
|
# CONFIGURATION
|
||||||
POLICY_PATH="ganatrask/lerobot-pi0-libero-object" # or outputs/train/.../pretrained_model
|
POLICY_PATH="ganatrask/lerobot-pi0-libero-object"
|
||||||
TASK=libero_object
|
TASK=libero_object
|
||||||
ENV_TYPE="libero"
|
ENV_TYPE="libero"
|
||||||
BATCH_SIZE=1
|
BATCH_SIZE=1
|
||||||
N_EPISODES=1
|
N_EPISODES=1
|
||||||
USE_AMP=false
|
|
||||||
DEVICE=cuda
|
|
||||||
|
|
||||||
# RUN EVALUATION
|
# RUN EVALUATION
|
||||||
python src/lerobot/scripts/eval.py \
|
python src/lerobot/scripts/eval.py \
|
||||||
@@ -295,8 +295,8 @@ class LiberoEnv(EnvConfig):
|
|||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"action": ACTION,
|
"action": ACTION,
|
||||||
"agent_pos": OBS_STATE,
|
"agent_pos": OBS_STATE,
|
||||||
"pixels/agentview_image": f"{OBS_IMAGE}",
|
"pixels/agentview_image": f"{OBS_IMAGES}.image",
|
||||||
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGE_2}",
|
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -41,12 +41,12 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
|||||||
Args:
|
Args:
|
||||||
cfg (EnvConfig): the config of the environment to instantiate.
|
cfg (EnvConfig): the config of the environment to instantiate.
|
||||||
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
|
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
|
||||||
use_async_envs (bool, optional): Wether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
|
use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
|
||||||
False.
|
False.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if n_envs < 1
|
ValueError: if n_envs < 1
|
||||||
ModuleNotFoundError: If the requested env package is not intalled
|
ModuleNotFoundError: If the requested env package is not installed
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
gym.vector.VectorEnv: The parallelized gym.env instance.
|
gym.vector.VectorEnv: The parallelized gym.env instance.
|
||||||
|
|||||||
+30
-36
@@ -26,65 +26,59 @@ from lerobot.configs.types import FeatureType, PolicyFeature
|
|||||||
from lerobot.envs.configs import EnvConfig
|
from lerobot.envs.configs import EnvConfig
|
||||||
from lerobot.utils.utils import get_channel_first_image_shape
|
from lerobot.utils.utils import get_channel_first_image_shape
|
||||||
|
|
||||||
|
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||||
def preprocess_observation(
|
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
|
||||||
observations: dict[str, np.ndarray], cfg: dict[str, Any] = None
|
|
||||||
) -> dict[str, Tensor]:
|
|
||||||
"""Convert environment observation to LeRobot format observation.
|
"""Convert environment observation to LeRobot format observation.
|
||||||
Args:
|
Args:
|
||||||
observations: Dictionary of observation batches from a Gym vector environment.
|
observation: Dictionary of observation batches from a Gym vector environment.
|
||||||
cfg: Policy config containing expected feature keys.
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary of observation batches with keys renamed to match policy expectations.
|
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||||
"""
|
"""
|
||||||
|
# map to expected inputs for the policy
|
||||||
return_observations = {}
|
return_observations = {}
|
||||||
|
|
||||||
# expected keys from policy
|
|
||||||
policy_img_keys = list(cfg.image_features.keys()) if cfg else ["observation.image"]
|
|
||||||
state_key = cfg.robot_state_feature_key if cfg else "observation.state"
|
|
||||||
|
|
||||||
# handle images
|
|
||||||
if "pixels" in observations:
|
if "pixels" in observations:
|
||||||
if isinstance(observations["pixels"], dict):
|
if isinstance(observations["pixels"], dict):
|
||||||
env_img_keys = list(observations["pixels"].keys())
|
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||||
imgs = observations["pixels"]
|
|
||||||
else:
|
else:
|
||||||
env_img_keys = ["pixels"]
|
imgs = {"observation.image": observations["pixels"]}
|
||||||
imgs = {"pixels": observations["pixels"]}
|
|
||||||
|
|
||||||
# build rename map env_key -> policy_key
|
|
||||||
rename_map = dict(zip(env_img_keys, policy_img_keys, strict=False))
|
|
||||||
|
|
||||||
for imgkey, img in imgs.items():
|
for imgkey, img in imgs.items():
|
||||||
target_key = rename_map.get(imgkey, imgkey)
|
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||||
|
|
||||||
img = torch.from_numpy(img)
|
img = torch.from_numpy(img)
|
||||||
|
|
||||||
# sanity checks
|
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
|
||||||
|
# This is the case for human-in-the-loop RL where there is only one environment.
|
||||||
|
if img.ndim == 3:
|
||||||
|
img = img.unsqueeze(0)
|
||||||
|
# sanity check that images are channel last
|
||||||
_, h, w, c = img.shape
|
_, h, w, c = img.shape
|
||||||
assert c < h and c < w, f"expect channel last images, got {img.shape=}"
|
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||||
assert img.dtype == torch.uint8, f"expect torch.uint8, got {img.dtype=}"
|
|
||||||
|
|
||||||
# channel last → channel first, normalize
|
# sanity check that images are uint8
|
||||||
|
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||||
|
|
||||||
|
# convert to channel first of type float32 in range [0,1]
|
||||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||||
img = img.float() / 255.0
|
img = img.type(torch.float32)
|
||||||
|
img /= 255
|
||||||
|
|
||||||
return_observations[target_key] = img
|
return_observations[imgkey] = img
|
||||||
|
|
||||||
# handle state
|
|
||||||
if "environment_state" in observations:
|
if "environment_state" in observations:
|
||||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
env_state = torch.from_numpy(observations["environment_state"]).float()
|
||||||
observations["environment_state"]
|
if env_state.dim() == 1:
|
||||||
).float()
|
env_state = env_state.unsqueeze(0)
|
||||||
|
|
||||||
return_observations[state_key] = torch.from_numpy(observations["agent_pos"]).float()
|
return_observations["observation.environment_state"] = env_state
|
||||||
|
|
||||||
if "task" in observations:
|
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||||
return_observations["task"] = observations["task"]
|
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
|
||||||
|
if agent_pos.dim() == 1:
|
||||||
|
agent_pos = agent_pos.unsqueeze(0)
|
||||||
|
return_observations["observation.state"] = agent_pos
|
||||||
|
|
||||||
return return_observations
|
return return_observations
|
||||||
|
|
||||||
|
|
||||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||||
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ import einops
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from termcolor import colored
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
@@ -73,6 +74,7 @@ from lerobot.policies.factory import make_policy
|
|||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import get_device_from_parameters
|
from lerobot.policies.utils import get_device_from_parameters
|
||||||
from lerobot.utils.io_utils import write_video
|
from lerobot.utils.io_utils import write_video
|
||||||
|
from lerobot.utils.random_utils import set_seed
|
||||||
from lerobot.utils.utils import (
|
from lerobot.utils.utils import (
|
||||||
get_safe_torch_device,
|
get_safe_torch_device,
|
||||||
init_logging,
|
init_logging,
|
||||||
@@ -146,8 +148,7 @@ def rollout(
|
|||||||
check_env_attributes_and_types(env)
|
check_env_attributes_and_types(env)
|
||||||
while not np.all(done) and step < max_steps:
|
while not np.all(done) and step < max_steps:
|
||||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||||
# observation = preprocess_observation(observation)
|
observation = preprocess_observation(observation)
|
||||||
observation = preprocess_observation(observation, cfg=policy.config)
|
|
||||||
if return_observations:
|
if return_observations:
|
||||||
all_observations.append(deepcopy(observation))
|
all_observations.append(deepcopy(observation))
|
||||||
|
|
||||||
@@ -459,24 +460,8 @@ def _compile_episode_data(
|
|||||||
|
|
||||||
return data_dict
|
return data_dict
|
||||||
|
|
||||||
|
|
||||||
def set_global_seed(seed):
|
|
||||||
"""Set seed for reproducibility."""
|
|
||||||
import random
|
|
||||||
|
|
||||||
random.seed(seed)
|
|
||||||
np.random.seed(seed)
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.manual_seed_all(seed)
|
|
||||||
|
|
||||||
|
|
||||||
def log_output_dir(out_dir):
|
|
||||||
logging.info("Output dir:" + f" {out_dir}")
|
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def eval(cfg: EvalPipelineConfig):
|
def eval_main(cfg: EvalPipelineConfig):
|
||||||
logging.info(pformat(asdict(cfg)))
|
logging.info(pformat(asdict(cfg)))
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
@@ -484,9 +469,9 @@ def eval(cfg: EvalPipelineConfig):
|
|||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
set_global_seed(cfg.seed)
|
set_seed(cfg.seed)
|
||||||
|
|
||||||
log_output_dir(cfg.output_dir)
|
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||||
|
|
||||||
logging.info("Making environment.")
|
logging.info("Making environment.")
|
||||||
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||||
@@ -494,11 +479,9 @@ def eval(cfg: EvalPipelineConfig):
|
|||||||
logging.info("Making policy.")
|
logging.info("Making policy.")
|
||||||
policy = make_policy(
|
policy = make_policy(
|
||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
# device=device,
|
|
||||||
env_cfg=cfg.env,
|
env_cfg=cfg.env,
|
||||||
)
|
)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||||
if cfg.env.multitask_eval:
|
if cfg.env.multitask_eval:
|
||||||
info = eval_policy_multitask(
|
info = eval_policy_multitask(
|
||||||
@@ -663,4 +646,4 @@ def eval_policy_multitask(
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
init_logging()
|
init_logging()
|
||||||
eval()
|
eval_main()
|
||||||
|
|||||||
@@ -186,7 +186,6 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
dl_iter = cycle(dataloader)
|
dl_iter = cycle(dataloader)
|
||||||
|
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
train_metrics = {
|
train_metrics = {
|
||||||
"loss": AverageMeter("loss", ":.3f"),
|
"loss": AverageMeter("loss", ":.3f"),
|
||||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||||
@@ -263,15 +262,14 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||||
)
|
)
|
||||||
aggregated = eval_info["overall"]["aggregated"]
|
aggregated = eval_info["overall"]["aggregated"]
|
||||||
# Print per-suite stats
|
# Print per-suite stats, log?
|
||||||
for task_group, task_group_info in eval_info.items():
|
for task_group, task_group_info in eval_info.items():
|
||||||
if task_group == "overall":
|
if task_group == "overall":
|
||||||
continue # Skip the overall stats since we already printed it
|
continue # Skip the overall stats since we already printed it
|
||||||
print(f"\nAggregated Metrics for {task_group}:")
|
print(f"\nAggregated Metrics for {task_group}:")
|
||||||
print(task_group_info["aggregated"])
|
print(task_group_info["aggregated"])
|
||||||
breakpoint()
|
breakpoint()
|
||||||
else:
|
else:
|
||||||
print("START EVAL")
|
|
||||||
eval_info = eval_policy(
|
eval_info = eval_policy(
|
||||||
eval_env,
|
eval_env,
|
||||||
policy,
|
policy,
|
||||||
@@ -280,9 +278,8 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
max_episodes_rendered=4,
|
max_episodes_rendered=4,
|
||||||
start_seed=cfg.seed,
|
start_seed=cfg.seed,
|
||||||
)
|
)
|
||||||
breakpoint()
|
|
||||||
aggregated = eval_info["aggregated"]
|
aggregated = eval_info["aggregated"]
|
||||||
print("END EVAL")
|
breakpoint()
|
||||||
|
|
||||||
eval_metrics = {
|
eval_metrics = {
|
||||||
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
||||||
|
|||||||
Reference in New Issue
Block a user