From 1c56779dd9ed6c24c07fba7e17fd44f5ce9b10f1 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Sun, 6 Jul 2025 22:18:19 +0200 Subject: [PATCH] chore (type): add typing for multiprocess envs --- src/lerobot/processor/pipeline.py | 8 ++++---- src/lerobot/scripts/eval.py | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index a7181dcc2..5e5f4c177 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -43,10 +43,10 @@ class TransitionIndex(IntEnum): # (observation, action, reward, done, truncated, info, complementary_data) EnvTransition = Tuple[ dict[str, Any] | None, # observation - Any | None, # action - float | None, # reward - bool | None, # done - bool | None, # truncated + Any | torch.Tensor | None, # action + float | torch.Tensor | None, # reward + bool | torch.Tensor | None, # done + bool | torch.Tensor | None, # truncated Dict[str, Any] | None, # info Dict[str, Any] | None, # complementary_data ] diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index c80da8138..7ea4a8995 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -159,12 +159,13 @@ def rollout( check_env_attributes_and_types(env) while not np.all(done): # Numpy array to tensor and changing dictionary keys to LeRobot policy format. - transition = (observation, None, None, None, None, info, None) + transition = (observation, None, None, None, None, None, None) processed_transition = obs_processor(transition) observation = processed_transition[TransitionIndex.OBSERVATION] if return_observations: all_observations.append(deepcopy(observation)) + # TODO(azouitine): Move this in processor side observation = { key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation }