Merge branch 'main' into feat/decouple_record_script

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Steven Palma
2026-04-19 22:48:08 +02:00
committed by GitHub
15 changed files with 240 additions and 153 deletions
@@ -52,6 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
)
batch = next(iter(dataloader))
for key in batch:
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
batch = preprocessor(batch)
loss, output_dict = policy.forward(batch)
@@ -82,6 +85,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
# indicating padding (those ending with "_is_pad")
dataset.reader.delta_indices = None
batch = next(iter(dataloader))
for key in batch:
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
obs = {}
for k in batch:
# TODO: regenerate the safetensors
+29
View File
@@ -454,6 +454,35 @@ def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory):
)
def test_cleanup_interrupted_episode_removes_image_temp_dirs(tmp_path, empty_lerobot_dataset_factory):
"""Verify interrupted episode cleanup removes temporary image directories for both image and video features."""
features = {
"image": {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]},
"video": {"dtype": "video", "shape": DUMMY_HWC, "names": ["height", "width", "channels"]},
}
ds = empty_lerobot_dataset_factory(
root=tmp_path / "interrupted", features=features, streaming_encoding=False
)
# Add one frame without saving episode simulating an interruption
ds.add_frame(
{
"image": np.random.rand(*DUMMY_CHW),
"video": np.random.rand(*DUMMY_HWC),
"task": "Dummy task",
}
)
img_dir = ds.writer._get_image_file_dir(0, "image")
vid_img_dir = ds.writer._get_image_file_dir(0, "video")
# Precondition: both temp dirs exist after add_frame.
assert img_dir.exists()
assert vid_img_dir.exists()
ds.writer.cleanup_interrupted_episode(episode_index=0)
assert not img_dir.exists(), "image temp dir leaked after cleanup_interrupted_episode"
assert not vid_img_dir.exists(), "video temp dir leaked after cleanup_interrupted_episode"
def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
"""Verify temporary image directories are removed appropriately when both image and video features are present."""
image_key = "image"
+2
View File
@@ -196,6 +196,8 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
for key in batch:
if isinstance(batch[key], torch.Tensor):
if batch[key].dtype == torch.uint8:
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
batch[key] = batch[key].to(DEVICE, non_blocking=True)
# Test updating the policy (and test that it does not mutate the batch)