mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
add multitask
This commit is contained in:
@@ -16,7 +16,7 @@ SAVE_FREQ=10000
|
|||||||
NUM_WORKERS=0
|
NUM_WORKERS=0
|
||||||
|
|
||||||
# model params
|
# model params
|
||||||
POLICY=smolvla
|
POLICY=pi0
|
||||||
USE_AMP=false
|
USE_AMP=false
|
||||||
OPTIMIZER_LR=1e-4
|
OPTIMIZER_LR=1e-4
|
||||||
PEFT_METHOD=lora
|
PEFT_METHOD=lora
|
||||||
@@ -30,11 +30,13 @@ USE_IMAGENET_STATS=false
|
|||||||
ENABLE_IMG_TRANSFORM=true
|
ENABLE_IMG_TRANSFORM=true
|
||||||
MAX_NUM_IMAGES=2
|
MAX_NUM_IMAGES=2
|
||||||
MAX_IMAGE_DIM=1024
|
MAX_IMAGE_DIM=1024
|
||||||
|
unset LEROBOT_HOME
|
||||||
|
unset HF_LEROBOT_HOME
|
||||||
|
|
||||||
echo -e "\033[1;33m[WARNING]\033[0m LIBERO is not yet fully supported in this PR!"
|
echo -e "\033[1;33m[WARNING]\033[0m LIBERO is not yet fully supported in this PR!"
|
||||||
|
|
||||||
# launch
|
# launch
|
||||||
PYTORCH_ENABLE_MPS_FALLBACK=1 DEVICE=cpu python src/lerobot/scripts/train.py \
|
python src/lerobot/scripts/train.py \
|
||||||
--policy.device=cpu \
|
|
||||||
--policy.type=$POLICY \
|
--policy.type=$POLICY \
|
||||||
--dataset.repo_id=$REPO_ID \
|
--dataset.repo_id=$REPO_ID \
|
||||||
--env.type=libero \
|
--env.type=libero \
|
||||||
@@ -45,11 +47,7 @@ PYTORCH_ENABLE_MPS_FALLBACK=1 DEVICE=cpu python src/lerobot/scripts/train.py \
|
|||||||
--eval_freq=$EVAL_FREQ \
|
--eval_freq=$EVAL_FREQ \
|
||||||
--save_freq=$SAVE_FREQ \
|
--save_freq=$SAVE_FREQ \
|
||||||
--num_workers=$NUM_WORKERS \
|
--num_workers=$NUM_WORKERS \
|
||||||
--policy.max_action_dim=$MAX_ACTION_DIM \
|
|
||||||
--policy.max_state_dim=$MAX_STATE_DIM \
|
|
||||||
--policy.use_amp=$USE_AMP \
|
|
||||||
--policy.optimizer_lr=$OPTIMIZER_LR \
|
|
||||||
--policy.load_vlm_weights=$LOAD_VLM_WEIGHTS \
|
|
||||||
--policy.repo_id=$VLM_REPO_ID \
|
--policy.repo_id=$VLM_REPO_ID \
|
||||||
--env.multitask_eval=False \
|
--env.multitask_eval=True \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
|
--eval.n_episodes=1 \
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Example evaluation script for LeRobot policies
|
||||||
|
unset LEROBOT_HOME
|
||||||
|
unset HF_LEROBOT_HOME
|
||||||
|
# === CONFIGURATION ===
|
||||||
|
POLICY_PATH="ganatrask/lerobot-pi0-libero-object" # or outputs/train/.../pretrained_model
|
||||||
|
TASK=libero_object
|
||||||
|
ENV_TYPE="libero"
|
||||||
|
BATCH_SIZE=1
|
||||||
|
N_EPISODES=1
|
||||||
|
USE_AMP=false
|
||||||
|
DEVICE=cuda
|
||||||
|
|
||||||
|
# === RUN EVALUATION ===
|
||||||
|
python src/lerobot/scripts/eval.py \
|
||||||
|
--policy.path="$POLICY_PATH" \
|
||||||
|
--env.type="$ENV_TYPE" \
|
||||||
|
--eval.batch_size="$BATCH_SIZE" \
|
||||||
|
--eval.n_episodes="$N_EPISODES" \
|
||||||
|
--env.multitask_eval=False \
|
||||||
|
--env.task=$TASK \
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 46 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 50 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 49 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 50 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 50 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 48 KiB |
@@ -19,10 +19,12 @@ from huggingface_hub.constants import HF_HOME
|
|||||||
|
|
||||||
OBS_ENV_STATE = "observation.environment_state"
|
OBS_ENV_STATE = "observation.environment_state"
|
||||||
OBS_STATE = "observation.state"
|
OBS_STATE = "observation.state"
|
||||||
# OBS_IMAGE = "observation.image"
|
OBS_IMAGE = "observation.image"
|
||||||
# OBS_IMAGE_2 = "observation.image2"
|
OBS_IMAGE_2 = "observation.image2"
|
||||||
OBS_IMAGE = "image"
|
OBS_IMAGE = "image"
|
||||||
OBS_IMAGE_2 = "wrist_image"
|
OBS_IMAGE_2 = "image2"
|
||||||
|
# OBS_IMAGE = "image"
|
||||||
|
# OBS_IMAGE_2 = "wrist_image"
|
||||||
OBS_IMAGES = "observation.images"
|
OBS_IMAGES = "observation.images"
|
||||||
ACTION = "action"
|
ACTION = "action"
|
||||||
REWARD = "next.reward"
|
REWARD = "next.reward"
|
||||||
|
|||||||
@@ -295,8 +295,8 @@ class LiberoEnv(EnvConfig):
|
|||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"action": ACTION,
|
"action": ACTION,
|
||||||
"agent_pos": OBS_STATE,
|
"agent_pos": OBS_STATE,
|
||||||
"pixels/agentview_image": f"{OBS_IMAGE}",
|
"pixels/agentview_image": f"observation.images.{OBS_IMAGE}",
|
||||||
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGE_2}",
|
"pixels/robot0_eye_in_hand_image": f"observation.images.{OBS_IMAGE_2}",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,9 @@ def create_libero_envs(
|
|||||||
Returns:
|
Returns:
|
||||||
dict[str, dict[str, list[LiberoEnv]]]: keys are task_suite and values are list of LiberoEnv envs.
|
dict[str, dict[str, list[LiberoEnv]]]: keys are task_suite and values are list of LiberoEnv envs.
|
||||||
"""
|
"""
|
||||||
|
print("num envs", n_envs)
|
||||||
|
print("multitask_eval", multitask_eval)
|
||||||
|
print("gym_kwargs", gym_kwargs)
|
||||||
if gym_kwargs is None:
|
if gym_kwargs is None:
|
||||||
gym_kwargs = {}
|
gym_kwargs = {}
|
||||||
|
|
||||||
@@ -45,6 +48,7 @@ def create_libero_envs(
|
|||||||
episode_indices = list(range(n_envs))
|
episode_indices = list(range(n_envs))
|
||||||
elif len(tasks_id) < n_envs and n_envs % len(tasks_id) == 0:
|
elif len(tasks_id) < n_envs and n_envs % len(tasks_id) == 0:
|
||||||
n_repeat = n_envs // len(tasks_id)
|
n_repeat = n_envs // len(tasks_id)
|
||||||
|
print("n_repeat", n_repeat)
|
||||||
episode_indices = []
|
episode_indices = []
|
||||||
for i in range(len(tasks_id)):
|
for i in range(len(tasks_id)):
|
||||||
episode_indices.extend(list(range(n_repeat)))
|
episode_indices.extend(list(range(n_repeat)))
|
||||||
@@ -313,11 +317,9 @@ class LiberoEnv(gym.Env):
|
|||||||
def step(self, action):
|
def step(self, action):
|
||||||
assert action.ndim == 1
|
assert action.ndim == 1
|
||||||
raw_obs, reward, done, info = self._env.step(action)
|
raw_obs, reward, done, info = self._env.step(action)
|
||||||
|
|
||||||
is_success = self._env.check_success()
|
is_success = self._env.check_success()
|
||||||
terminated = done or is_success
|
terminated = done or is_success
|
||||||
info["is_success"] = is_success
|
info["is_success"] = is_success
|
||||||
print(f"[LiberoEnv.step] done={done}, is_success={is_success}, terminated={terminated}")
|
|
||||||
observation = self._format_raw_obs(raw_obs)
|
observation = self._format_raw_obs(raw_obs)
|
||||||
truncated = False
|
truncated = False
|
||||||
# note if it is unable to complete get libero error after many steps
|
# note if it is unable to complete get libero error after many steps
|
||||||
|
|||||||
@@ -97,7 +97,6 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
|||||||
|
|
||||||
policy_key = env_cfg.features_map[key]
|
policy_key = env_cfg.features_map[key]
|
||||||
policy_features[policy_key] = feature
|
policy_features[policy_key] = feature
|
||||||
|
|
||||||
return policy_features
|
return policy_features
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -156,7 +156,6 @@ def make_policy(
|
|||||||
"by default without stats from a dataset."
|
"by default without stats from a dataset."
|
||||||
)
|
)
|
||||||
features = env_to_policy_features(env_cfg)
|
features = env_to_policy_features(env_cfg)
|
||||||
|
|
||||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||||
kwargs["config"] = cfg
|
kwargs["config"] = cfg
|
||||||
|
|||||||
+43
-11
@@ -56,7 +56,7 @@ from copy import deepcopy
|
|||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
import concurrent
|
||||||
import einops
|
import einops
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -156,10 +156,29 @@ def rollout(
|
|||||||
# Infer "task" from attributes of environments.
|
# Infer "task" from attributes of environments.
|
||||||
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
||||||
observation = add_envs_task(env, observation)
|
observation = add_envs_task(env, observation)
|
||||||
|
if step % 100 == 0:
|
||||||
|
import imageio.v2 as imageio
|
||||||
|
|
||||||
|
img = observation["observation.images.image"] # (1, 3, 256, 256)
|
||||||
|
|
||||||
|
if isinstance(img, torch.Tensor):
|
||||||
|
img = img.detach().cpu().numpy()
|
||||||
|
|
||||||
|
# remove batch → (3, 256, 256)
|
||||||
|
img = img[0]
|
||||||
|
|
||||||
|
# transpose → (256, 256, 3)
|
||||||
|
img = np.transpose(img, (1, 2, 0))
|
||||||
|
|
||||||
|
# scale + convert to uint8
|
||||||
|
img = (img * 255).clip(0, 255).astype(np.uint8)
|
||||||
|
|
||||||
|
# now works
|
||||||
|
imageio.imwrite(f"obs_{step:06d}.png", img)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
action = policy.select_action(observation)
|
action = policy.select_action(observation)
|
||||||
|
observation['observation.images.image']
|
||||||
# Convert to CPU / numpy.
|
# Convert to CPU / numpy.
|
||||||
action = action.to("cpu").numpy()
|
action = action.to("cpu").numpy()
|
||||||
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||||
@@ -177,7 +196,12 @@ def rollout(
|
|||||||
successes = [False] * env.num_envs
|
successes = [False] * env.num_envs
|
||||||
|
|
||||||
# Keep track of which environments are done so far.
|
# Keep track of which environments are done so far.
|
||||||
|
# done = terminated | truncated | done
|
||||||
|
#TODO: jadechoghari changed, this is cleaner
|
||||||
done = terminated | truncated | done
|
done = terminated | truncated | done
|
||||||
|
if step + 1 == max_steps:
|
||||||
|
done = np.ones_like(done, dtype=bool)
|
||||||
|
|
||||||
|
|
||||||
all_actions.append(torch.from_numpy(action))
|
all_actions.append(torch.from_numpy(action))
|
||||||
all_rewards.append(torch.from_numpy(reward))
|
all_rewards.append(torch.from_numpy(reward))
|
||||||
@@ -185,7 +209,6 @@ def rollout(
|
|||||||
all_successes.append(torch.tensor(successes))
|
all_successes.append(torch.tensor(successes))
|
||||||
|
|
||||||
step += 1
|
step += 1
|
||||||
print(step)
|
|
||||||
running_success_rate = (
|
running_success_rate = (
|
||||||
# einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean() #TODO: changed by jade
|
# einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean() #TODO: changed by jade
|
||||||
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "max")
|
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "max")
|
||||||
@@ -254,6 +277,7 @@ def eval_policy(
|
|||||||
# Determine how many batched rollouts we need to get n_episodes. Note that if n_episodes is not evenly
|
# Determine how many batched rollouts we need to get n_episodes. Note that if n_episodes is not evenly
|
||||||
# divisible by env.num_envs we end up discarding some data in the last batch.
|
# divisible by env.num_envs we end up discarding some data in the last batch.
|
||||||
n_batches = n_episodes // env.num_envs + int((n_episodes % env.num_envs) != 0)
|
n_batches = n_episodes // env.num_envs + int((n_episodes % env.num_envs) != 0)
|
||||||
|
print("n_batches", n_batches)
|
||||||
|
|
||||||
# Keep track of some metrics.
|
# Keep track of some metrics.
|
||||||
sum_rewards = []
|
sum_rewards = []
|
||||||
@@ -374,7 +398,7 @@ def eval_policy(
|
|||||||
# Wait till all video rendering threads are done.
|
# Wait till all video rendering threads are done.
|
||||||
for thread in threads:
|
for thread in threads:
|
||||||
thread.join()
|
thread.join()
|
||||||
|
|
||||||
# Compile eval info.
|
# Compile eval info.
|
||||||
info = {
|
info = {
|
||||||
"per_episode": [
|
"per_episode": [
|
||||||
@@ -403,7 +427,6 @@ def eval_policy(
|
|||||||
"eval_ep_s": (time.time() - start) / n_episodes,
|
"eval_ep_s": (time.time() - start) / n_episodes,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if return_episode_data:
|
if return_episode_data:
|
||||||
info["episodes"] = episode_data
|
info["episodes"] = episode_data
|
||||||
|
|
||||||
@@ -457,13 +480,22 @@ def _compile_episode_data(
|
|||||||
|
|
||||||
return data_dict
|
return data_dict
|
||||||
|
|
||||||
|
def set_global_seed(seed):
|
||||||
|
"""Set seed for reproducibility."""
|
||||||
|
import random
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
def log_output_dir(out_dir):
|
||||||
|
logging.info("Output dir:"+ f" {out_dir}")
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def eval(cfg: EvalPipelineConfig):
|
def eval(cfg: EvalPipelineConfig):
|
||||||
logging.info(pformat(asdict(cfg)))
|
logging.info(pformat(asdict(cfg)))
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
device = get_safe_torch_device(cfg.device, log=True)
|
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
@@ -477,12 +509,12 @@ def eval(cfg: EvalPipelineConfig):
|
|||||||
logging.info("Making policy.")
|
logging.info("Making policy.")
|
||||||
policy = make_policy(
|
policy = make_policy(
|
||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
device=device,
|
# device=device,
|
||||||
env_cfg=cfg.env,
|
env_cfg=cfg.env,
|
||||||
)
|
)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||||
if cfg.env.multitask_eval:
|
if cfg.env.multitask_eval:
|
||||||
info = eval_policy_multitask(
|
info = eval_policy_multitask(
|
||||||
env,
|
env,
|
||||||
@@ -555,7 +587,7 @@ def eval_policy_multitask(
|
|||||||
videos_dir,
|
videos_dir,
|
||||||
return_episode_data,
|
return_episode_data,
|
||||||
start_seed,
|
start_seed,
|
||||||
verbose=verbose,
|
# verbose=verbose,
|
||||||
)
|
)
|
||||||
|
|
||||||
per_episode = task_result["per_episode"]
|
per_episode = task_result["per_episode"]
|
||||||
@@ -642,4 +674,4 @@ def eval_policy_multitask(
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
init_logging()
|
init_logging()
|
||||||
eval_main()
|
eval()
|
||||||
|
|||||||
@@ -269,7 +269,10 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
continue # Skip the overall stats since we already printed it
|
continue # Skip the overall stats since we already printed it
|
||||||
print(f"\nAggregated Metrics for {task_group}:")
|
print(f"\nAggregated Metrics for {task_group}:")
|
||||||
print(task_group_info["aggregated"])
|
print(task_group_info["aggregated"])
|
||||||
|
breakpoint()
|
||||||
else:
|
else:
|
||||||
|
print("START EVAL")
|
||||||
|
breakpoint()
|
||||||
eval_info = eval_policy(
|
eval_info = eval_policy(
|
||||||
eval_env,
|
eval_env,
|
||||||
policy,
|
policy,
|
||||||
@@ -278,6 +281,8 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
max_episodes_rendered=4,
|
max_episodes_rendered=4,
|
||||||
start_seed=cfg.seed,
|
start_seed=cfg.seed,
|
||||||
)
|
)
|
||||||
|
aggregated = eval_info["aggregated"]
|
||||||
|
print("END EVAL")
|
||||||
|
|
||||||
eval_metrics = {
|
eval_metrics = {
|
||||||
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
||||||
@@ -287,9 +292,9 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
eval_tracker = MetricsTracker(
|
eval_tracker = MetricsTracker(
|
||||||
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
|
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
|
||||||
)
|
)
|
||||||
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
|
eval_tracker.eval_s = aggregated.pop("eval_s")
|
||||||
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
|
eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
|
||||||
eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
|
eval_tracker.pc_success = aggregated.pop("pc_success")
|
||||||
logging.info(eval_tracker)
|
logging.info(eval_tracker)
|
||||||
if wandb_logger:
|
if wandb_logger:
|
||||||
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
||||||
|
|||||||
Reference in New Issue
Block a user