diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index 2cf9efcfe..e5c5d6a4c 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -97,6 +97,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: policy_key = env_cfg.features_map[key] policy_features[policy_key] = feature + return policy_features diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 4b8eeffd1..2d51a3881 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -156,6 +156,7 @@ def make_policy( "by default without stats from a dataset." ) features = env_to_policy_features(env_cfg) + cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features} kwargs["config"] = cfg @@ -168,6 +169,7 @@ def make_policy( else: # Make a fresh policy. policy = policy_cls(**kwargs) + policy.to(cfg.device) assert isinstance(policy, nn.Module)