mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
Added the capability to record a dataset
This commit is contained in:
@@ -27,6 +27,7 @@ import torchvision.transforms.functional as F # noqa: N812
|
|||||||
from lerobot.cameras import opencv # noqa: F401
|
from lerobot.cameras import opencv # noqa: F401
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.types import PolicyFeature
|
from lerobot.configs.types import PolicyFeature
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.envs.configs import EnvConfig
|
from lerobot.envs.configs import EnvConfig
|
||||||
from lerobot.model.kinematics import RobotKinematics
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
@@ -279,7 +280,7 @@ class RobotEnv(gym.Env):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register("joint_velocity_processor_")
|
@ProcessorStepRegistry.register("joint_velocity_processor")
|
||||||
class JointVelocityProcessor:
|
class JointVelocityProcessor:
|
||||||
"""Add joint velocity information to observations.
|
"""Add joint velocity information to observations.
|
||||||
|
|
||||||
@@ -589,18 +590,18 @@ class InterventionActionProcessor:
|
|||||||
new_transition[TransitionKey.ACTION] = teleop_action_tensor
|
new_transition[TransitionKey.ACTION] = teleop_action_tensor
|
||||||
|
|
||||||
# Handle episode termination
|
# Handle episode termination
|
||||||
if terminate_episode:
|
new_transition[TransitionKey.DONE] = bool(terminate_episode)
|
||||||
new_transition[TransitionKey.DONE] = True
|
new_transition[TransitionKey.REWARD] = float(success)
|
||||||
if success:
|
|
||||||
new_transition[TransitionKey.REWARD] = 1.0
|
|
||||||
|
|
||||||
# Update info with intervention metadata
|
# Update info with intervention metadata
|
||||||
info = new_transition.get(TransitionKey.INFO, {})
|
info = new_transition.get(TransitionKey.INFO, {})
|
||||||
info["is_intervention"] = is_intervention
|
info["is_intervention"] = is_intervention
|
||||||
info["action_intervention"] = new_transition[TransitionKey.ACTION]
|
|
||||||
info["rerecord_episode"] = rerecord_episode
|
info["rerecord_episode"] = rerecord_episode
|
||||||
info["next.success"] = success if terminate_episode else info.get("next.success", False)
|
info["next.success"] = success if terminate_episode else info.get("next.success", False)
|
||||||
new_transition[TransitionKey.INFO] = info
|
new_transition[TransitionKey.INFO] = info
|
||||||
|
new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"] = new_transition[
|
||||||
|
TransitionKey.ACTION
|
||||||
|
]
|
||||||
|
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|
||||||
@@ -797,10 +798,49 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
|
|||||||
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
||||||
transition = env_processor(transition)
|
transition = env_processor(transition)
|
||||||
|
|
||||||
|
if cfg.mode == "record":
|
||||||
|
action_features = teleop_device.action_features
|
||||||
|
features = {
|
||||||
|
"action": action_features,
|
||||||
|
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
||||||
|
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||||
|
"complementary_info.discrete_penalty": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": ["discrete_penalty"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value in transition[TransitionKey.OBSERVATION].items():
|
||||||
|
if key == "observation.state":
|
||||||
|
features[key] = {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": value.squeeze(0).shape,
|
||||||
|
"names": None,
|
||||||
|
}
|
||||||
|
if "image" in key:
|
||||||
|
features[key] = {
|
||||||
|
"dtype": "video",
|
||||||
|
"shape": value.squeeze(0).shape,
|
||||||
|
"names": ["channels", "height", "width"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create dataset
|
||||||
|
dataset = LeRobotDataset.create(
|
||||||
|
cfg.repo_id,
|
||||||
|
cfg.fps,
|
||||||
|
root=cfg.dataset_root,
|
||||||
|
use_videos=True,
|
||||||
|
image_writer_threads=4,
|
||||||
|
image_writer_processes=0,
|
||||||
|
features=features,
|
||||||
|
)
|
||||||
|
|
||||||
|
episode_idx = 0
|
||||||
episode_step = 0
|
episode_step = 0
|
||||||
episode_start_time = time.perf_counter()
|
episode_start_time = time.perf_counter()
|
||||||
|
|
||||||
while True:
|
while episode_idx < cfg.num_episodes:
|
||||||
step_start_time = time.perf_counter()
|
step_start_time = time.perf_counter()
|
||||||
|
|
||||||
# Get teleoperation action and extra signals
|
# Get teleoperation action and extra signals
|
||||||
@@ -827,14 +867,20 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
|
|||||||
|
|
||||||
# Extract processed action and metadata
|
# Extract processed action and metadata
|
||||||
processed_action = processed_action_transition[TransitionKey.ACTION]
|
processed_action = processed_action_transition[TransitionKey.ACTION]
|
||||||
action_info = processed_action_transition.get(TransitionKey.INFO, {})
|
|
||||||
terminate_episode = processed_action_transition.get(TransitionKey.DONE, False)
|
terminate_episode = processed_action_transition.get(TransitionKey.DONE, False)
|
||||||
|
|
||||||
# Step environment with processed action
|
# Step environment with processed action
|
||||||
obs, reward, terminated, truncated, info = env.step(processed_action)
|
obs, reward, terminated, truncated, info = env.step(processed_action)
|
||||||
|
|
||||||
|
reward = reward + processed_action_transition[TransitionKey.REWARD]
|
||||||
|
|
||||||
# Process new observation
|
# Process new observation
|
||||||
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
|
complementary_data = {
|
||||||
|
"raw_joint_positions": info.pop("raw_joint_positions"),
|
||||||
|
**processed_action_transition[TransitionKey.COMPLEMENTARY_DATA],
|
||||||
|
}
|
||||||
|
info.update(processed_action_transition[TransitionKey.INFO])
|
||||||
|
|
||||||
transition = create_transition(
|
transition = create_transition(
|
||||||
observation=obs,
|
observation=obs,
|
||||||
action=processed_action,
|
action=processed_action,
|
||||||
@@ -848,14 +894,27 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
|
|||||||
terminated = transition.get(TransitionKey.DONE, False)
|
terminated = transition.get(TransitionKey.DONE, False)
|
||||||
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
||||||
|
|
||||||
|
if cfg.mode == "record":
|
||||||
|
observations = {k: v.squeeze(0) for k, v in transition[TransitionKey.OBSERVATION].items()}
|
||||||
|
frame = {
|
||||||
|
**observations,
|
||||||
|
"action": transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"],
|
||||||
|
"next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
||||||
|
"next.done": np.array([terminated or truncated], dtype=bool),
|
||||||
|
"complementary_info.discrete_penalty": np.array(
|
||||||
|
[transition[TransitionKey.COMPLEMENTARY_DATA]["discrete_penalty"]], dtype=np.float32
|
||||||
|
),
|
||||||
|
}
|
||||||
|
dataset.add_frame(frame, task=cfg.task)
|
||||||
|
|
||||||
episode_step += 1
|
episode_step += 1
|
||||||
|
|
||||||
# Handle episode termination
|
# Handle episode termination
|
||||||
if terminated or truncated or terminate_episode:
|
if terminated or truncated or terminate_episode:
|
||||||
episode_end_reason = "success" if action_info.get("next.success", False) else "terminated"
|
|
||||||
episode_time = time.perf_counter() - episode_start_time
|
episode_time = time.perf_counter() - episode_start_time
|
||||||
print(f"Episode ended ({episode_end_reason}) after {episode_step} steps in {episode_time:.1f}s")
|
logging.info(
|
||||||
print(f"Rerecord episode: {action_info.get('rerecord_episode', False)}")
|
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
|
||||||
|
)
|
||||||
|
|
||||||
# Reset for new episode
|
# Reset for new episode
|
||||||
obs, info = env.reset()
|
obs, info = env.reset()
|
||||||
@@ -867,11 +926,24 @@ def control_loop(env, env_processor, action_processor, teleop_device, cfg: EnvCo
|
|||||||
transition = env_processor(transition)
|
transition = env_processor(transition)
|
||||||
|
|
||||||
episode_step = 0
|
episode_step = 0
|
||||||
episode_start_time = time.perf_counter()
|
episode_idx += 1
|
||||||
|
|
||||||
|
if cfg.mode == "record":
|
||||||
|
if transition[TransitionKey.INFO].get("rerecord_episode", False):
|
||||||
|
logging.info(f"Re-recording episode {episode_idx}")
|
||||||
|
dataset.clear_episode_buffer()
|
||||||
|
episode_idx -= 1
|
||||||
|
else:
|
||||||
|
logging.info(f"Saving episode {episode_idx}")
|
||||||
|
dataset.save_episode()
|
||||||
|
|
||||||
# Maintain fps timing
|
# Maintain fps timing
|
||||||
busy_wait(dt - (time.perf_counter() - step_start_time))
|
busy_wait(dt - (time.perf_counter() - step_start_time))
|
||||||
|
|
||||||
|
if cfg.mode == "record" and cfg.push_to_hub:
|
||||||
|
logging.info("Pushing dataset to hub")
|
||||||
|
dataset.push_to_hub()
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def main(cfg: EnvConfig):
|
def main(cfg: EnvConfig):
|
||||||
@@ -914,7 +986,6 @@ def main(cfg: EnvConfig):
|
|||||||
print("Environment processor:", env_processor)
|
print("Environment processor:", env_processor)
|
||||||
print("Action processor:", action_processor)
|
print("Action processor:", action_processor)
|
||||||
|
|
||||||
# Run the control loop
|
|
||||||
control_loop(env, env_processor, action_processor, teleop_device, cfg)
|
control_loop(env, env_processor, action_processor, teleop_device, cfg)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user