Added the capability to record a dataset

This commit is contained in:
Michel Aractingi
2025-08-02 17:14:14 +02:00
parent 1fdbecad3c
commit 384101731e
+85 -14
View File
@@ -27,6 +27,7 @@ import torchvision.transforms.functional as F # noqa: N812
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.types import PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.envs.configs import EnvConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
@@ -279,7 +280,7 @@ class RobotEnv(gym.Env):
@dataclass
@ProcessorStepRegistry.register("joint_velocity_processor_")
@ProcessorStepRegistry.register("joint_velocity_processor")
class JointVelocityProcessor:
"""Add joint velocity information to observations.
@@ -589,18 +590,18 @@ class InterventionActionProcessor:
new_transition[TransitionKey.ACTION] = teleop_action_tensor
# Handle episode termination
if terminate_episode:
new_transition[TransitionKey.DONE] = True
if success:
new_transition[TransitionKey.REWARD] = 1.0
new_transition[TransitionKey.DONE] = bool(terminate_episode)
new_transition[TransitionKey.REWARD] = float(success)
# Update info with intervention metadata
info = new_transition.get(TransitionKey.INFO, {})
info["is_intervention"] = is_intervention
info["action_intervention"] = new_transition[TransitionKey.ACTION]
info["rerecord_episode"] = rerecord_episode
info["next.success"] = success if terminate_episode else info.get("next.success", False)
new_transition[TransitionKey.INFO] = info
new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"] = new_transition[
TransitionKey.ACTION
]
return new_transition
@@ -797,10 +798,49 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
transition = env_processor(transition)
if cfg.mode == "record":
action_features = teleop_device.action_features
features = {
"action": action_features,
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
"complementary_info.discrete_penalty": {
"dtype": "float32",
"shape": (1,),
"names": ["discrete_penalty"],
},
}
for key, value in transition[TransitionKey.OBSERVATION].items():
if key == "observation.state":
features[key] = {
"dtype": "float32",
"shape": value.squeeze(0).shape,
"names": None,
}
if "image" in key:
features[key] = {
"dtype": "video",
"shape": value.squeeze(0).shape,
"names": ["channels", "height", "width"],
}
# Create dataset
dataset = LeRobotDataset.create(
cfg.repo_id,
cfg.fps,
root=cfg.dataset_root,
use_videos=True,
image_writer_threads=4,
image_writer_processes=0,
features=features,
)
episode_idx = 0
episode_step = 0
episode_start_time = time.perf_counter()
while True:
while episode_idx < cfg.num_episodes:
step_start_time = time.perf_counter()
# Get teleoperation action and extra signals
@@ -827,14 +867,20 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
# Extract processed action and metadata
processed_action = processed_action_transition[TransitionKey.ACTION]
action_info = processed_action_transition.get(TransitionKey.INFO, {})
terminate_episode = processed_action_transition.get(TransitionKey.DONE, False)
# Step environment with processed action
obs, reward, terminated, truncated, info = env.step(processed_action)
reward = reward + processed_action_transition[TransitionKey.REWARD]
# Process new observation
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
complementary_data = {
"raw_joint_positions": info.pop("raw_joint_positions"),
**processed_action_transition[TransitionKey.COMPLEMENTARY_DATA],
}
info.update(processed_action_transition[TransitionKey.INFO])
transition = create_transition(
observation=obs,
action=processed_action,
@@ -848,14 +894,27 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
terminated = transition.get(TransitionKey.DONE, False)
truncated = transition.get(TransitionKey.TRUNCATED, False)
if cfg.mode == "record":
observations = {k: v.squeeze(0) for k, v in transition[TransitionKey.OBSERVATION].items()}
frame = {
**observations,
"action": transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"],
"next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
"next.done": np.array([terminated or truncated], dtype=bool),
"complementary_info.discrete_penalty": np.array(
[transition[TransitionKey.COMPLEMENTARY_DATA]["discrete_penalty"]], dtype=np.float32
),
}
dataset.add_frame(frame, task=cfg.task)
episode_step += 1
# Handle episode termination
if terminated or truncated or terminate_episode:
episode_end_reason = "success" if action_info.get("next.success", False) else "terminated"
episode_time = time.perf_counter() - episode_start_time
print(f"Episode ended ({episode_end_reason}) after {episode_step} steps in {episode_time:.1f}s")
print(f"Rerecord episode: {action_info.get('rerecord_episode', False)}")
logging.info(
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
)
# Reset for new episode
obs, info = env.reset()
@@ -867,11 +926,24 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
transition = env_processor(transition)
episode_step = 0
episode_start_time = time.perf_counter()
episode_idx += 1
if cfg.mode == "record":
if transition[TransitionKey.INFO].get("rerecord_episode", False):
logging.info(f"Re-recording episode {episode_idx}")
dataset.clear_episode_buffer()
episode_idx -= 1
else:
logging.info(f"Saving episode {episode_idx}")
dataset.save_episode()
# Maintain fps timing
busy_wait(dt - (time.perf_counter() - step_start_time))
if cfg.mode == "record" and cfg.push_to_hub:
logging.info("Pushing dataset to hub")
dataset.push_to_hub()
@parser.wrap()
def main(cfg: EnvConfig):
@@ -914,7 +986,6 @@ def main(cfg: EnvConfig):
print("Environment processor:", env_processor)
print("Action processor:", action_processor)
# Run the control loop
control_loop(env, env_processor, action_processor, teleop_device, cfg)