mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 14:09:47 +00:00
more changes
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
lerobot-eval \
|
||||
--policy.path="/raid/jade/models/xvla-libero-new_migrated" \
|
||||
--policy.path="/raid/jade/models/xvla-lib" \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial \
|
||||
--env.action_type=abs \
|
||||
--env.control_mode=absolute \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--seed=142
|
||||
|
||||
|
||||
@@ -272,7 +272,7 @@ class LiberoEnv(EnvConfig):
|
||||
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
|
||||
}
|
||||
)
|
||||
action_type: str = "rel"
|
||||
control_mode: str = "relative" # or "absolute"
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
|
||||
@@ -24,6 +24,8 @@ from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_m
|
||||
from lerobot.processor import ProcessorStep
|
||||
from lerobot.processor.env_processor import LiberoProcessorStep
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
@@ -39,6 +41,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
|
||||
def make_env_pre_post_processors(
|
||||
env_cfg: EnvConfig,
|
||||
policy_cfg: PreTrainedConfig,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
@@ -61,11 +64,15 @@ def make_env_pre_post_processors(
|
||||
# Preprocessor and Postprocessor steps are Identity for most environments
|
||||
preprocessor_steps: list[ProcessorStep] = []
|
||||
postprocessor_steps: list[ProcessorStep] = []
|
||||
if isinstance(policy_cfg, XVLAConfig):
|
||||
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
|
||||
return make_xvla_libero_pre_post_processors()
|
||||
|
||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
preprocessor_steps.append(LiberoProcessorStep())
|
||||
|
||||
|
||||
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
|
||||
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
|
||||
|
||||
@@ -136,7 +143,7 @@ def make_env(
|
||||
init_states=cfg.init_states,
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
action_type=cfg.action_type,
|
||||
control_mode=cfg.control_mode,
|
||||
)
|
||||
elif "metaworld" in cfg.type:
|
||||
from lerobot.envs.metaworld import create_metaworld_envs
|
||||
|
||||
@@ -87,7 +87,7 @@ AGENT_POS_HIGH = 1000.0
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
||||
"libero_spatial": 280, # longest training demo has 193 steps
|
||||
"libero_spatial": 800, # longest training demo has 193 steps
|
||||
"libero_object": 280, # longest training demo has 254 steps
|
||||
"libero_goal": 300, # longest training demo has 270 steps
|
||||
"libero_10": 520, # longest training demo has 505 steps
|
||||
@@ -114,6 +114,7 @@ class LiberoEnv(gym.Env):
|
||||
episode_index: int = 0,
|
||||
camera_name_mapping: dict[str, str] | None = None,
|
||||
num_steps_wait: int = 10,
|
||||
control_mode: str = "relative",
|
||||
):
|
||||
super().__init__()
|
||||
self.task_id = task_id
|
||||
@@ -148,7 +149,7 @@ class LiberoEnv(gym.Env):
|
||||
self._env = self._make_envs_task(task_suite, self.task_id)
|
||||
default_steps = 500
|
||||
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
||||
|
||||
self.control_mode = control_mode
|
||||
images = {}
|
||||
for cam in self.camera_name:
|
||||
images[self.camera_name_mapping[cam]] = spaces.Box(
|
||||
@@ -239,7 +240,7 @@ class LiberoEnv(gym.Env):
|
||||
image = raw_obs[camera_name]
|
||||
images[self.camera_name_mapping[camera_name]] = image
|
||||
|
||||
eef_pos = raw_obs.get("robot0_eef_pos")
|
||||
eef_pos = self._env.robots[0].controller.ee_pos #raw_obs.get("robot0_eef_pos")
|
||||
eef_quat = raw_obs.get("robot0_eef_quat")
|
||||
|
||||
# rotation matrix from controller
|
||||
@@ -296,6 +297,15 @@ class LiberoEnv(gym.Env):
|
||||
# Increasing this value can improve determinism and reproducibility across resets.
|
||||
for _ in range(self.num_steps_wait):
|
||||
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
|
||||
|
||||
if self.control_mode == "absolute":
|
||||
for robot in self._env.robots:
|
||||
robot.controller.use_delta = False
|
||||
elif self.control_mode == "relative":
|
||||
for robot in self._env.robots:
|
||||
robot.controller.use_delta = True
|
||||
else:
|
||||
raise ValueError(f"Invalid control mode: {self.control_mode}")
|
||||
observation = self._format_raw_obs(raw_obs)
|
||||
info = {"is_success": False}
|
||||
return observation, info
|
||||
@@ -343,6 +353,7 @@ def _make_env_fns(
|
||||
camera_names: list[str],
|
||||
init_states: bool,
|
||||
gym_kwargs: Mapping[str, Any],
|
||||
control_mode: str,
|
||||
) -> list[Callable[[], LiberoEnv]]:
|
||||
"""Build n_envs factory callables for a single (suite, task_id)."""
|
||||
|
||||
@@ -355,6 +366,7 @@ def _make_env_fns(
|
||||
camera_name=camera_names,
|
||||
init_states=init_states,
|
||||
episode_index=episode_index,
|
||||
control_mode=control_mode,
|
||||
**local_kwargs,
|
||||
)
|
||||
|
||||
@@ -374,6 +386,7 @@ def create_libero_envs(
|
||||
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
||||
init_states: bool = True,
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
control_mode: str = "relative",
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""
|
||||
Create vectorized LIBERO environments with a consistent return shape.
|
||||
@@ -409,6 +422,7 @@ def create_libero_envs(
|
||||
suite = _get_suite(suite_name)
|
||||
total = len(suite.tasks)
|
||||
selected = _select_task_ids(total, task_ids_filter)
|
||||
selected = [0]
|
||||
if not selected:
|
||||
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
|
||||
|
||||
@@ -421,6 +435,7 @@ def create_libero_envs(
|
||||
camera_names=camera_names,
|
||||
init_states=init_states,
|
||||
gym_kwargs=gym_kwargs,
|
||||
control_mode=control_mode,
|
||||
)
|
||||
out[suite_name][tid] = env_cls(fns)
|
||||
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
from lerobot.policies.xvla.processor_xvla import (
|
||||
XVLAAddDomainIdProcessorStep,
|
||||
XVLAImageScaleProcessorStep,
|
||||
XVLARotation6DToAxisAngleProcessorStep,
|
||||
make_xvla_pre_post_processors,
|
||||
)
|
||||
@@ -315,7 +315,6 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
return total_loss, log_dict
|
||||
|
||||
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
print("get_action_chunk")
|
||||
inputs = self._build_model_inputs(batch)
|
||||
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
|
||||
actions = self._trim_action_dim(actions)
|
||||
@@ -361,7 +360,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
import safetensors.torch
|
||||
|
||||
# --- Step 1: Load config ---
|
||||
# step 1: load config
|
||||
if config is None:
|
||||
config = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
@@ -377,7 +376,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
instance = cls(config, **kwargs)
|
||||
# --- Step 2: Locate model.safetensors ---
|
||||
# step 2: locate model.safetensors
|
||||
if os.path.isdir(model_id):
|
||||
print("Loading weights from local directory")
|
||||
model_file = os.path.join(model_id, "model.safetensors")
|
||||
@@ -401,13 +400,14 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
|
||||
|
||||
print(f"Loading checkpoint from {model_file}")
|
||||
# step 3: load state dict
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
|
||||
shared_key = "model.vlm.language_model.model.shared.weight"
|
||||
if encoder_key in state_dict:
|
||||
state_dict[shared_key] = state_dict[encoder_key]
|
||||
# or deepcopy
|
||||
# step 5: load into instance
|
||||
# step 4: load into instance
|
||||
missing, unexpected = instance.load_state_dict(state_dict, strict=True)
|
||||
print("Loaded XVLA checkpoint")
|
||||
if missing:
|
||||
@@ -415,7 +415,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
if unexpected:
|
||||
print(f"Unexpected keys: {unexpected}")
|
||||
|
||||
# step 6: finalize
|
||||
# step 5: finalize
|
||||
instance.to(config.device)
|
||||
instance.eval()
|
||||
return instance
|
||||
|
||||
@@ -28,6 +28,7 @@ from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ObservationProcessorStep,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
@@ -36,8 +37,8 @@ from lerobot.processor import (
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME, OBS_STATE, OBS_IMAGES
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
|
||||
def make_xvla_pre_post_processors(
|
||||
config: XVLAConfig,
|
||||
@@ -89,6 +90,127 @@ def make_xvla_pre_post_processors(
|
||||
|
||||
|
||||
# Custom XVLA processor steps
|
||||
@dataclass
|
||||
class LiberoProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processes LIBERO observations into the LeRobot format.
|
||||
|
||||
This step handles the specific observation structure from LIBERO environments,
|
||||
which includes nested robot_state dictionaries and image observations.
|
||||
|
||||
**State Processing:**
|
||||
- Processes the `robot_state` dictionary which contains nested end-effector,
|
||||
gripper, and joint information.
|
||||
- Extracts and concatenates:
|
||||
- End-effector position (3D)
|
||||
- End-effector quaternion converted to axis-angle (3D)
|
||||
- Gripper joint positions (2D)
|
||||
- Maps the concatenated state to `"observation.state"`.
|
||||
|
||||
**Image Processing:**
|
||||
- Rotates images by 180 degrees by flipping both height and width dimensions.
|
||||
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
|
||||
"""
|
||||
|
||||
def _process_observation(self, observation):
|
||||
"""
|
||||
Processes both image and robot_state observations from LIBERO.
|
||||
"""
|
||||
processed_obs = observation.copy()
|
||||
for key in list(processed_obs.keys()):
|
||||
if key.startswith(f"{OBS_IMAGES}."):
|
||||
img = processed_obs[key]
|
||||
|
||||
if key == f"{OBS_IMAGES}.image":
|
||||
# Flip both H and W
|
||||
img = torch.flip(img, dims=[2, 3])
|
||||
|
||||
processed_obs[key] = img
|
||||
# Process robot_state into a flat state vector
|
||||
if "observation.robot_state" in processed_obs:
|
||||
robot_state = processed_obs.pop("observation.robot_state")
|
||||
|
||||
# Extract components
|
||||
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
|
||||
eef_mat = robot_state["eef"]["mat"] # (B, 3, 3)
|
||||
eef_rot6d = self._mat_to_rotate6d(eef_mat) # (B, 6)
|
||||
|
||||
extra = torch.zeros((eef_pos.shape[0], 1), dtype=torch.float32, device=eef_pos.device)
|
||||
|
||||
proprio_state = torch.cat((eef_pos, eef_rot6d, extra), dim=-1) # (B, 10)
|
||||
state = torch.cat((proprio_state, torch.zeros_like(proprio_state)), dim=-1) # (B, 20)
|
||||
# ensure float32
|
||||
state = state.float()
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
processed_obs[OBS_STATE] = state
|
||||
return processed_obs
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Transforms feature keys from the LIBERO format to the LeRobot standard.
|
||||
"""
|
||||
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
|
||||
|
||||
# copy over non-STATE features
|
||||
for ft, feats in features.items():
|
||||
if ft != PipelineFeatureType.STATE:
|
||||
new_features[ft] = feats.copy()
|
||||
|
||||
# rebuild STATE features
|
||||
state_feats = {}
|
||||
|
||||
# add our new flattened state
|
||||
state_feats["observation.state"] = PolicyFeature(
|
||||
key="observation.state",
|
||||
shape=(20,),
|
||||
dtype="float32",
|
||||
)
|
||||
|
||||
new_features[PipelineFeatureType.STATE] = state_feats
|
||||
|
||||
return new_features
|
||||
|
||||
def _mat_to_rotate6d(self, rot_mats: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert batched rotation matrices (B, 3, 3) into 6D rotation representation (B, 6).
|
||||
|
||||
Args:
|
||||
rot_mats (Tensor): Rotation matrices of shape (B, 3, 3)
|
||||
|
||||
Returns:
|
||||
Tensor: 6D rotation representation, shape (B, 6)
|
||||
|
||||
Raises:
|
||||
TypeError: if input is not a torch tensor
|
||||
ValueError: if shape is not (B, 3, 3)
|
||||
"""
|
||||
|
||||
if not isinstance(rot_mats, torch.Tensor):
|
||||
raise TypeError(
|
||||
f"mat_to_rot6d expects a torch.Tensor, got {type(rot_mats)}"
|
||||
)
|
||||
|
||||
if rot_mats.ndim != 3 or rot_mats.shape[1:] != (3, 3):
|
||||
raise ValueError(
|
||||
f"mat_to_rot6d expects shape (B, 3, 3), got {tuple(rot_mats.shape)}"
|
||||
)
|
||||
|
||||
rot_mats = rot_mats.to(torch.float32)
|
||||
|
||||
col1 = rot_mats[:, :3, 0] # (B, 3)
|
||||
col2 = rot_mats[:, :3, 1] # (B, 3)
|
||||
|
||||
rot6d = torch.cat([col1, col2], dim=-1) # (B, 6)
|
||||
|
||||
return rot6d
|
||||
|
||||
def observation(self, observation):
|
||||
return self._process_observation(observation)
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -254,3 +376,24 @@ class XVLARotation6DToAxisAngleProcessorStep(ProcessorStep):
|
||||
return {
|
||||
"expected_action_dim": self.expected_action_dim,
|
||||
}
|
||||
|
||||
def make_xvla_libero_pre_post_processors(
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Build the LeRobot processor pipelines for XVLA with LIBERO environment.
|
||||
"""
|
||||
pre_processor_steps: list[ProcessorStep] = []
|
||||
post_processor_steps: list[ProcessorStep] = []
|
||||
pre_processor_steps.extend([LiberoProcessorStep(), XVLAImageScaleProcessorStep(), XVLAAddDomainIdProcessorStep()])
|
||||
post_processor_steps.extend([XVLARotation6DToAxisAngleProcessorStep()])
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=pre_processor_steps,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=post_processor_steps,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -328,7 +328,6 @@ class _NormalizationMixin:
|
||||
if norm_mode == NormalizationMode.IMAGENET:
|
||||
mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
|
||||
std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
|
||||
|
||||
# Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
|
||||
while mean.dim() < tensor.dim():
|
||||
mean = mean.unsqueeze(0)
|
||||
|
||||
@@ -175,11 +175,9 @@ def rollout(
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
action = postprocessor(action)
|
||||
|
||||
action_transition = {"action": action}
|
||||
action_transition = env_postprocessor(action_transition)
|
||||
action = action_transition["action"]
|
||||
|
||||
# Convert to CPU / numpy.
|
||||
action_numpy: np.ndarray = action.to("cpu").numpy()
|
||||
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||
@@ -533,7 +531,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
)
|
||||
|
||||
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
info = eval_policy_all(
|
||||
|
||||
Reference in New Issue
Block a user