mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
chore (type): add typing for multiprocess envs
This commit is contained in:
@@ -43,10 +43,10 @@ class TransitionIndex(IntEnum):
|
|||||||
# (observation, action, reward, done, truncated, info, complementary_data)
|
# (observation, action, reward, done, truncated, info, complementary_data)
|
||||||
EnvTransition = Tuple[
|
EnvTransition = Tuple[
|
||||||
dict[str, Any] | None, # observation
|
dict[str, Any] | None, # observation
|
||||||
Any | None, # action
|
Any | torch.Tensor | None, # action
|
||||||
float | None, # reward
|
float | torch.Tensor | None, # reward
|
||||||
bool | None, # done
|
bool | torch.Tensor | None, # done
|
||||||
bool | None, # truncated
|
bool | torch.Tensor | None, # truncated
|
||||||
Dict[str, Any] | None, # info
|
Dict[str, Any] | None, # info
|
||||||
Dict[str, Any] | None, # complementary_data
|
Dict[str, Any] | None, # complementary_data
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -159,12 +159,13 @@ def rollout(
|
|||||||
check_env_attributes_and_types(env)
|
check_env_attributes_and_types(env)
|
||||||
while not np.all(done):
|
while not np.all(done):
|
||||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||||
transition = (observation, None, None, None, None, info, None)
|
transition = (observation, None, None, None, None, None, None)
|
||||||
processed_transition = obs_processor(transition)
|
processed_transition = obs_processor(transition)
|
||||||
observation = processed_transition[TransitionIndex.OBSERVATION]
|
observation = processed_transition[TransitionIndex.OBSERVATION]
|
||||||
if return_observations:
|
if return_observations:
|
||||||
all_observations.append(deepcopy(observation))
|
all_observations.append(deepcopy(observation))
|
||||||
|
|
||||||
|
# TODO(azouitine): Move this in processor side
|
||||||
observation = {
|
observation = {
|
||||||
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
|
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user