chore (type): add typing for multiprocess envs

This commit is contained in:
Adil Zouitine
2025-07-06 22:18:19 +02:00
parent 83a4338f8b
commit 1c56779dd9
2 changed files with 6 additions and 5 deletions
+4 -4
View File
@@ -43,10 +43,10 @@ class TransitionIndex(IntEnum):
# (observation, action, reward, done, truncated, info, complementary_data) # (observation, action, reward, done, truncated, info, complementary_data)
EnvTransition = Tuple[ EnvTransition = Tuple[
dict[str, Any] | None, # observation dict[str, Any] | None, # observation
Any | None, # action Any | torch.Tensor | None, # action
float | None, # reward float | torch.Tensor | None, # reward
bool | None, # done bool | torch.Tensor | None, # done
bool | None, # truncated bool | torch.Tensor | None, # truncated
Dict[str, Any] | None, # info Dict[str, Any] | None, # info
Dict[str, Any] | None, # complementary_data Dict[str, Any] | None, # complementary_data
] ]
+2 -1
View File
@@ -159,12 +159,13 @@ def rollout(
check_env_attributes_and_types(env) check_env_attributes_and_types(env)
while not np.all(done): while not np.all(done):
# Numpy array to tensor and changing dictionary keys to LeRobot policy format. # Numpy array to tensor and changing dictionary keys to LeRobot policy format.
transition = (observation, None, None, None, None, info, None) transition = (observation, None, None, None, None, None, None)
processed_transition = obs_processor(transition) processed_transition = obs_processor(transition)
observation = processed_transition[TransitionIndex.OBSERVATION] observation = processed_transition[TransitionIndex.OBSERVATION]
if return_observations: if return_observations:
all_observations.append(deepcopy(observation)) all_observations.append(deepcopy(observation))
# TODO(azouitine): Move this in processor side
observation = { observation = {
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
} }