mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
Added the capability to record a dataset
This commit is contained in:
@@ -27,6 +27,7 @@ import torchvision.transforms.functional as F # noqa: N812
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
@@ -279,7 +280,7 @@ class RobotEnv(gym.Env):
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("joint_velocity_processor_")
|
||||
@ProcessorStepRegistry.register("joint_velocity_processor")
|
||||
class JointVelocityProcessor:
|
||||
"""Add joint velocity information to observations.
|
||||
|
||||
@@ -589,18 +590,18 @@ class InterventionActionProcessor:
|
||||
new_transition[TransitionKey.ACTION] = teleop_action_tensor
|
||||
|
||||
# Handle episode termination
|
||||
if terminate_episode:
|
||||
new_transition[TransitionKey.DONE] = True
|
||||
if success:
|
||||
new_transition[TransitionKey.REWARD] = 1.0
|
||||
new_transition[TransitionKey.DONE] = bool(terminate_episode)
|
||||
new_transition[TransitionKey.REWARD] = float(success)
|
||||
|
||||
# Update info with intervention metadata
|
||||
info = new_transition.get(TransitionKey.INFO, {})
|
||||
info["is_intervention"] = is_intervention
|
||||
info["action_intervention"] = new_transition[TransitionKey.ACTION]
|
||||
info["rerecord_episode"] = rerecord_episode
|
||||
info["next.success"] = success if terminate_episode else info.get("next.success", False)
|
||||
new_transition[TransitionKey.INFO] = info
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"] = new_transition[
|
||||
TransitionKey.ACTION
|
||||
]
|
||||
|
||||
return new_transition
|
||||
|
||||
@@ -797,10 +798,49 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
|
||||
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
||||
transition = env_processor(transition)
|
||||
|
||||
if cfg.mode == "record":
|
||||
action_features = teleop_device.action_features
|
||||
features = {
|
||||
"action": action_features,
|
||||
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||
"complementary_info.discrete_penalty": {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": ["discrete_penalty"],
|
||||
},
|
||||
}
|
||||
|
||||
for key, value in transition[TransitionKey.OBSERVATION].items():
|
||||
if key == "observation.state":
|
||||
features[key] = {
|
||||
"dtype": "float32",
|
||||
"shape": value.squeeze(0).shape,
|
||||
"names": None,
|
||||
}
|
||||
if "image" in key:
|
||||
features[key] = {
|
||||
"dtype": "video",
|
||||
"shape": value.squeeze(0).shape,
|
||||
"names": ["channels", "height", "width"],
|
||||
}
|
||||
|
||||
# Create dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.repo_id,
|
||||
cfg.fps,
|
||||
root=cfg.dataset_root,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
image_writer_processes=0,
|
||||
features=features,
|
||||
)
|
||||
|
||||
episode_idx = 0
|
||||
episode_step = 0
|
||||
episode_start_time = time.perf_counter()
|
||||
|
||||
while True:
|
||||
while episode_idx < cfg.num_episodes:
|
||||
step_start_time = time.perf_counter()
|
||||
|
||||
# Get teleoperation action and extra signals
|
||||
@@ -827,14 +867,20 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
|
||||
|
||||
# Extract processed action and metadata
|
||||
processed_action = processed_action_transition[TransitionKey.ACTION]
|
||||
action_info = processed_action_transition.get(TransitionKey.INFO, {})
|
||||
terminate_episode = processed_action_transition.get(TransitionKey.DONE, False)
|
||||
|
||||
# Step environment with processed action
|
||||
obs, reward, terminated, truncated, info = env.step(processed_action)
|
||||
|
||||
reward = reward + processed_action_transition[TransitionKey.REWARD]
|
||||
|
||||
# Process new observation
|
||||
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
|
||||
complementary_data = {
|
||||
"raw_joint_positions": info.pop("raw_joint_positions"),
|
||||
**processed_action_transition[TransitionKey.COMPLEMENTARY_DATA],
|
||||
}
|
||||
info.update(processed_action_transition[TransitionKey.INFO])
|
||||
|
||||
transition = create_transition(
|
||||
observation=obs,
|
||||
action=processed_action,
|
||||
@@ -848,14 +894,27 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
|
||||
terminated = transition.get(TransitionKey.DONE, False)
|
||||
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
||||
|
||||
if cfg.mode == "record":
|
||||
observations = {k: v.squeeze(0) for k, v in transition[TransitionKey.OBSERVATION].items()}
|
||||
frame = {
|
||||
**observations,
|
||||
"action": transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"],
|
||||
"next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
||||
"next.done": np.array([terminated or truncated], dtype=bool),
|
||||
"complementary_info.discrete_penalty": np.array(
|
||||
[transition[TransitionKey.COMPLEMENTARY_DATA]["discrete_penalty"]], dtype=np.float32
|
||||
),
|
||||
}
|
||||
dataset.add_frame(frame, task=cfg.task)
|
||||
|
||||
episode_step += 1
|
||||
|
||||
# Handle episode termination
|
||||
if terminated or truncated or terminate_episode:
|
||||
episode_end_reason = "success" if action_info.get("next.success", False) else "terminated"
|
||||
episode_time = time.perf_counter() - episode_start_time
|
||||
print(f"Episode ended ({episode_end_reason}) after {episode_step} steps in {episode_time:.1f}s")
|
||||
print(f"Rerecord episode: {action_info.get('rerecord_episode', False)}")
|
||||
logging.info(
|
||||
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
|
||||
)
|
||||
|
||||
# Reset for new episode
|
||||
obs, info = env.reset()
|
||||
@@ -867,11 +926,24 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
|
||||
transition = env_processor(transition)
|
||||
|
||||
episode_step = 0
|
||||
episode_start_time = time.perf_counter()
|
||||
episode_idx += 1
|
||||
|
||||
if cfg.mode == "record":
|
||||
if transition[TransitionKey.INFO].get("rerecord_episode", False):
|
||||
logging.info(f"Re-recording episode {episode_idx}")
|
||||
dataset.clear_episode_buffer()
|
||||
episode_idx -= 1
|
||||
else:
|
||||
logging.info(f"Saving episode {episode_idx}")
|
||||
dataset.save_episode()
|
||||
|
||||
# Maintain fps timing
|
||||
busy_wait(dt - (time.perf_counter() - step_start_time))
|
||||
|
||||
if cfg.mode == "record" and cfg.push_to_hub:
|
||||
logging.info("Pushing dataset to hub")
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: EnvConfig):
|
||||
@@ -914,7 +986,6 @@ def main(cfg: EnvConfig):
|
||||
print("Environment processor:", env_processor)
|
||||
print("Action processor:", action_processor)
|
||||
|
||||
# Run the control loop
|
||||
control_loop(env, env_processor, action_processor, teleop_device, cfg)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user