mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
match state and exoected obs dimensions
This commit is contained in:
@@ -42,7 +42,6 @@ from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_featur
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline, make_default_processors
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import predict_action
|
||||
@@ -319,18 +318,17 @@ def run_ee_inference_loop(
|
||||
ee_state = joints_to_ee(joint_state.copy())
|
||||
|
||||
# 3. Build observation frame with EE state for policy input
|
||||
# Get expected state dimension from policy's input features
|
||||
state_feature = policy.config.input_features.get("observation.state")
|
||||
expected_dim = state_feature.shape[0] if state_feature else None
|
||||
|
||||
# Build state array from EE values (sorted keys)
|
||||
ee_keys = sorted(ee_state.keys())
|
||||
# Filter to only EE keys (FK may include other keys in output)
|
||||
# Expected: left_ee.{x,y,z,wx,wy,wz,gripper_pos}, right_ee.{...} = 14 total
|
||||
ee_keys = sorted([k for k in ee_state.keys() if "_ee." in k])
|
||||
ee_values = [ee_state[k] for k in ee_keys]
|
||||
|
||||
# Truncate to match expected dimension (FK may output more than policy expects)
|
||||
if expected_dim and len(ee_values) > expected_dim:
|
||||
ee_values = ee_values[:expected_dim]
|
||||
ee_keys = ee_keys[:expected_dim]
|
||||
# Debug: print on first step
|
||||
if step == 0:
|
||||
print(f" FK output keys ({len(ee_keys)}): {ee_keys}")
|
||||
state_feature = policy.config.input_features.get("observation.state")
|
||||
if state_feature:
|
||||
print(f" Policy expects state dim: {state_feature.shape[0]}")
|
||||
|
||||
# Store current EE position for relative action conversion (using same order)
|
||||
current_ee_pos = torch.tensor(ee_values)
|
||||
@@ -364,8 +362,12 @@ def run_ee_inference_loop(
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
|
||||
# 5. Convert action tensor to dict
|
||||
ee_action = make_robot_action(action_tensor, dataset.features)
|
||||
# 5. Convert action tensor to dict using EE keys (not joint keys from eval dataset)
|
||||
action_tensor = action_tensor.squeeze(0).cpu()
|
||||
while action_tensor.dim() > 1:
|
||||
action_tensor = action_tensor[0]
|
||||
# Use the same EE keys we used for state (truncated to match policy's action dim)
|
||||
ee_action = {ee_keys[i]: float(action_tensor[i]) for i in range(len(action_tensor))}
|
||||
|
||||
# 6. Convert relative action back to absolute if needed
|
||||
if use_relative_actions:
|
||||
@@ -515,17 +517,17 @@ def main():
|
||||
)
|
||||
print(" Dataset created")
|
||||
|
||||
# Load policy
|
||||
# Load policy directly using from_pretrained to preserve original EE features
|
||||
# (make_policy would overwrite output_features with joint features from eval dataset)
|
||||
print(f"\n[4/5] Loading policy from {HF_MODEL_ID}...")
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
|
||||
# Pass dataset meta for policy creation (required by make_policy)
|
||||
policy = make_policy(policy_config, ds_meta=dataset.meta)
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_cls = get_policy_class(policy_config.type)
|
||||
policy = policy_cls.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Load preprocessor/postprocessor from pretrained model
|
||||
# DO NOT pass dataset_stats - let it load from pretrained model
|
||||
# (evaluation dataset has different features than training dataset)
|
||||
# (uses the trained EE features, not joint features from eval dataset)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy.config,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
@@ -534,6 +536,8 @@ def main():
|
||||
},
|
||||
)
|
||||
print(" Policy loaded")
|
||||
print(f" State dim: {policy.config.input_features['observation.state'].shape[0]}")
|
||||
print(f" Action dim: {policy.config.output_features['action'].shape[0]}")
|
||||
|
||||
# Auto-detect relative action/state settings from checkpoint
|
||||
relative_normalizer, use_relative_actions, use_relative_state = load_relative_config(HF_MODEL_ID)
|
||||
|
||||
Reference in New Issue
Block a user