mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-22 09:29:44 +00:00
fix agibot2lerobot and update dirty tasks
This commit is contained in:
@@ -34,10 +34,12 @@ def load_local_dataset(
|
||||
for key in AgiBotWorld_CONFIG["actions"]:
|
||||
action[f"actions.{key}"] = np.array(f["action/" + key.replace(".", "/")], dtype=np.float32)
|
||||
|
||||
# HACK: agibot team forgot to pad some of the values
|
||||
# HACK: agibot team forgot to pad or filter some of the values
|
||||
num_frames = len(next(iter(state.values())))
|
||||
for action_key, action_value in action.items():
|
||||
if action_value.size and len(action_value) != num_frames:
|
||||
if 0 == len(action_value):
|
||||
print("0 action occurs, padding all with zeros later")
|
||||
elif len(action_value) < num_frames:
|
||||
state_key = action_key.replace("actions", "state").replace(".", "/")
|
||||
new_action_value = np.array(f[state_key], dtype=np.float32).copy()
|
||||
action_index_key = "/".join(list(action_key.replace("actions", "action").split(".")[:-1]) + ["index"])
|
||||
@@ -48,6 +50,9 @@ def load_local_dataset(
|
||||
action_index = np.array(f[action_index_key])
|
||||
new_action_value[action_index] = action_value
|
||||
action[action_key] = new_action_value
|
||||
elif len(action_value) > num_frames:
|
||||
print("corrupt data, skipping")
|
||||
return episode_id, [], {"dummy_video": Path("/path/to/no_exist")}
|
||||
|
||||
if save_depth:
|
||||
depth_imgs = load_depths(ob_dir / "depth", "head_depth")
|
||||
|
||||
Reference in New Issue
Block a user