mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
chore (type): add typing for multiprocess envs
This commit is contained in:
@@ -43,10 +43,10 @@ class TransitionIndex(IntEnum):
|
||||
# (observation, action, reward, done, truncated, info, complementary_data)
|
||||
EnvTransition = Tuple[
|
||||
dict[str, Any] | None, # observation
|
||||
Any | None, # action
|
||||
float | None, # reward
|
||||
bool | None, # done
|
||||
bool | None, # truncated
|
||||
Any | torch.Tensor | None, # action
|
||||
float | torch.Tensor | None, # reward
|
||||
bool | torch.Tensor | None, # done
|
||||
bool | torch.Tensor | None, # truncated
|
||||
Dict[str, Any] | None, # info
|
||||
Dict[str, Any] | None, # complementary_data
|
||||
]
|
||||
|
||||
@@ -159,12 +159,13 @@ def rollout(
|
||||
check_env_attributes_and_types(env)
|
||||
while not np.all(done):
|
||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||
transition = (observation, None, None, None, None, info, None)
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processed_transition = obs_processor(transition)
|
||||
observation = processed_transition[TransitionIndex.OBSERVATION]
|
||||
if return_observations:
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
# TODO(azouitine): Move this in processor side
|
||||
observation = {
|
||||
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user