mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 10:40:04 +00:00
Merge branch 'main' into feat/decouple_record_script
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -52,6 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
|
||||
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
|
||||
batch = preprocessor(batch)
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
@@ -82,6 +85,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
# indicating padding (those ending with "_is_pad")
|
||||
dataset.reader.delta_indices = None
|
||||
batch = next(iter(dataloader))
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
|
||||
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
|
||||
obs = {}
|
||||
for k in batch:
|
||||
# TODO: regenerate the safetensors
|
||||
|
||||
@@ -454,6 +454,35 @@ def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
)
|
||||
|
||||
|
||||
def test_cleanup_interrupted_episode_removes_image_temp_dirs(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Verify interrupted episode cleanup removes temporary image directories for both image and video features."""
|
||||
features = {
|
||||
"image": {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]},
|
||||
"video": {"dtype": "video", "shape": DUMMY_HWC, "names": ["height", "width", "channels"]},
|
||||
}
|
||||
ds = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "interrupted", features=features, streaming_encoding=False
|
||||
)
|
||||
# Add one frame without saving episode simulating an interruption
|
||||
ds.add_frame(
|
||||
{
|
||||
"image": np.random.rand(*DUMMY_CHW),
|
||||
"video": np.random.rand(*DUMMY_HWC),
|
||||
"task": "Dummy task",
|
||||
}
|
||||
)
|
||||
img_dir = ds.writer._get_image_file_dir(0, "image")
|
||||
vid_img_dir = ds.writer._get_image_file_dir(0, "video")
|
||||
# Precondition: both temp dirs exist after add_frame.
|
||||
assert img_dir.exists()
|
||||
assert vid_img_dir.exists()
|
||||
|
||||
ds.writer.cleanup_interrupted_episode(episode_index=0)
|
||||
|
||||
assert not img_dir.exists(), "image temp dir leaked after cleanup_interrupted_episode"
|
||||
assert not vid_img_dir.exists(), "video temp dir leaked after cleanup_interrupted_episode"
|
||||
|
||||
|
||||
def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Verify temporary image directories are removed appropriately when both image and video features are present."""
|
||||
image_key = "image"
|
||||
|
||||
@@ -196,6 +196,8 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
if batch[key].dtype == torch.uint8:
|
||||
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
|
||||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
||||
|
||||
# Test updating the policy (and test that it does not mutate the batch)
|
||||
|
||||
Reference in New Issue
Block a user