chore (type): add typing for multiprocess envs

This commit is contained in:
Adil Zouitine
2025-07-06 22:18:19 +02:00
parent 83a4338f8b
commit 1c56779dd9
2 changed files with 6 additions and 5 deletions
+4 -4
View File
@@ -43,10 +43,10 @@ class TransitionIndex(IntEnum):
# (observation, action, reward, done, truncated, info, complementary_data)
EnvTransition = Tuple[
dict[str, Any] | None, # observation
Any | None, # action
float | None, # reward
bool | None, # done
bool | None, # truncated
Any | torch.Tensor | None, # action
float | torch.Tensor | None, # reward
bool | torch.Tensor | None, # done
bool | torch.Tensor | None, # truncated
Dict[str, Any] | None, # info
Dict[str, Any] | None, # complementary_data
]
+2 -1
View File
@@ -159,12 +159,13 @@ def rollout(
check_env_attributes_and_types(env)
while not np.all(done):
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
transition = (observation, None, None, None, None, info, None)
transition = (observation, None, None, None, None, None, None)
processed_transition = obs_processor(transition)
observation = processed_transition[TransitionIndex.OBSERVATION]
if return_observations:
all_observations.append(deepcopy(observation))
# TODO(azouitine): Move this in processor side
observation = {
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
}