Compare commits

..

3 Commits

Author SHA1 Message Date
Jade Choghari 62d23b0986 add for rest of policies 2026-02-27 16:32:33 +01:00
Jade Choghari a6a2f3662a Merge branch 'main' into speedup-pi05-launch 2026-02-27 18:12:21 +03:00
Jeremiah Coholich 49444652c6 speedup pi-05 modeling loading by 72s 2026-02-20 15:41:44 -05:00
24 changed files with 61 additions and 221 deletions
-25
View File
@@ -1,25 +0,0 @@
# AI Usage Policy
The LeRobot project welcomes contributions from everyone, and we have a few guidelines regarding AI usage to ensure high code quality, clear communication, and a healthy open-source ecosystem:
- **Please disclose significant AI assistance.** If you used AI tools (e.g., Copilot, Claude, Cursor, ChatGPT) to generate a substantial portion of your code or text, let us know in your PR description. Transparency helps us review your changes more effectively.
- **Own your code (The Human-in-the-Loop).** You must fully understand all the changes you are proposing. If you cannot explain what your AI-assisted code does or how it interacts with LeRobot's broader architecture, please take the time to learn and test it before submitting.
- **Keep issues and discussions focused.** You are welcome to use AI to help draft issues or PR descriptions, but please review and edit them carefully before posting. AI can often be overly verbose; trimming the noise and getting straight to the point helps our maintainers address your needs faster.
Our core maintainers also use AI tools to aid their workflows, but they do so while bringing deep contextual knowledge of the LeRobot codebase to validate the output. We ask all contributors to apply that same level of rigor.
## Remember the Human Maintainers
Please remember that LeRobot is maintained by a dedicated team of humans.
Every discussion, issue, and pull request is read and reviewed by real people. While AI tools can generate thousands of lines of code in seconds, reviewing that code still takes human time and energy. Submitting unverified or low-effort AI output puts an unfair burden on our maintainers.
Today, the quality of the AI output still heavily depends on the developer driving the tool. We ask that you respect our maintainers' time by thoroughly vetting, testing, and refining your submissions.
## AI is Welcome Here
LeRobot operates at the cutting edge of AI and robotics, and many of our maintainers actively embrace AI coding assistants as valuable productivity tools. We are a pro-AI project!
Our reason for having an AI policy is not an anti-AI stance. Rather, it exists to ensure that AI is used to enhance human contributions, not replace them with unverified noise. It's about how the tools are used, not the tools themselves.
We value the unique human insight you bring to the LeRobot community. Let AI empower your workflow, but always let your own judgment take the wheel.
+1 -1
View File
@@ -2,7 +2,7 @@
Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable.
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md) and our [AI policy](./AI_POLICY.md).
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md).
## Ways to Contribute
+1 -1
View File
@@ -57,7 +57,7 @@ class DatasetReplayConfig:
repo_id: str
# Episode to replay.
episode: int
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int = 30
+1 -1
View File
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.4.5"
version = "0.4.4"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
dynamic = ["readme"]
license = { text = "Apache-2.0" }
+1 -1
View File
@@ -27,7 +27,7 @@ class DatasetConfig:
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
repo_id: str
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
+5 -5
View File
@@ -664,11 +664,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
for the README).
Args:
repo_id (str): This is the repo id that will be used to fetch the dataset.
root (Path | None, optional): Local directory where the dataset will be downloaded and
stored. If set, all dataset files will be stored directly under this path. If not set, the
dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the
HF_LEROBOT_HOME environment variable).
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
will be stored under root/repo_id.
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
'~/.cache/huggingface/lerobot'.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
+1 -3
View File
@@ -122,14 +122,12 @@ class DynamixelMotorsBus(SerialMotorsBus):
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
protocol_version: int = PROTOCOL_VERSION,
):
super().__init__(port, motors, calibration)
import dynamixel_sdk as dxl
self.port_handler = dxl.PortHandler(self.port)
self.packet_handler = dxl.PacketHandler(protocol_version)
print(f"Using protocol version {protocol_version}")
self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)
self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
self.sync_writer = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0)
self._comm_success = dxl.COMM_SUCCESS
-69
View File
@@ -33,58 +33,6 @@
# 2. We can change the value of the MyControlTableKey enums without impacting the client code
# {data_name: (address, size_byte)}
# https://emanual.robotis.com/docs/en/dxl/ax/{MODEL}/#control-table
AX_SERIES_CONTROL_TABLE = {
# EEPROM Area
"Model_Number": (0, 2),
"Firmware_Version": (2, 1),
"ID": (3, 1),
"Baud_Rate": (4, 1),
"Return_Delay_Time": (5, 1),
"CW_Angle_Limit": (6, 2),
"CCW_Angle_Limit": (8, 2),
"Temperature_Limit": (11, 1),
"Min_Voltage_Limit": (12, 1),
"Max_Voltage_Limit": (13, 1),
"Max_Torque": (14, 2),
"Status_Return_Level": (16, 1),
"Alarm_LED": (17, 1),
"Shutdown": (18, 1),
# RAM Area
"Torque_Enable": (24, 1),
"LED": (25, 1),
"CW_Compliance_Margin": (26, 1),
"CCW_Compliance_Margin": (27, 1),
"CW_Compliance_Slope": (28, 1),
"CCW_Compliance_Slope": (29, 1),
"Goal_Position": (30, 2),
"Moving_Speed": (32, 2),
"Torque_Limit": (34, 2),
"Present_Position": (36, 2),
"Present_Speed": (38, 2),
"Present_Load": (40, 2),
"Present_Voltage": (42, 1),
"Present_Temperature": (43, 1),
"Registered": (44, 1),
"Moving": (46, 1),
"Lock": (47, 1),
"Punch": (48, 2),
}
# https://emanual.robotis.com/docs/en/dxl/ax/{MODEL}/#baud-rate4
AX_SERIES_BAUDRATE_TABLE = {
9_600: 207,
19_200: 103,
57_600: 34,
115_200: 16,
200_000: 9,
250_000: 7,
400_000: 4,
500_000: 3,
1_000_000: 1,
}
# {data_name: (address, size_byte)}
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#control-table
X_SERIES_CONTROL_TABLE = {
@@ -166,14 +114,6 @@ X_SERIES_ENCODINGS_TABLE = {
"Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1],
}
# {data_name: size_byte}
AX_SERIES_ENCODINGS_TABLE = {
"Goal_Position": AX_SERIES_CONTROL_TABLE["Goal_Position"][1],
"Moving_Speed": AX_SERIES_CONTROL_TABLE["Moving_Speed"][1],
"Present_Position": AX_SERIES_CONTROL_TABLE["Present_Position"][1],
"Present_Speed": AX_SERIES_CONTROL_TABLE["Present_Speed"][1],
}
MODEL_ENCODING_TABLE = {
"x_series": X_SERIES_ENCODINGS_TABLE,
"xl330-m077": X_SERIES_ENCODINGS_TABLE,
@@ -182,8 +122,6 @@ MODEL_ENCODING_TABLE = {
"xm430-w350": X_SERIES_ENCODINGS_TABLE,
"xm540-w270": X_SERIES_ENCODINGS_TABLE,
"xc430-w150": X_SERIES_ENCODINGS_TABLE,
"ax_series": AX_SERIES_ENCODINGS_TABLE,
"ax-12a": AX_SERIES_ENCODINGS_TABLE,
}
# {model: model_resolution}
@@ -196,8 +134,6 @@ MODEL_RESOLUTION = {
"xm430-w350": 4096,
"xm540-w270": 4096,
"xc430-w150": 4096,
"ax_series": 1024,
"ax-12a": 1024,
}
# {model: model_number}
@@ -209,7 +145,6 @@ MODEL_NUMBER_TABLE = {
"xm430-w350": 1020,
"xm540-w270": 1120,
"xc430-w150": 1070,
"ax-12a": 12,
}
# {model: available_operating_modes}
@@ -231,8 +166,6 @@ MODEL_CONTROL_TABLE = {
"xm430-w350": X_SERIES_CONTROL_TABLE,
"xm540-w270": X_SERIES_CONTROL_TABLE,
"xc430-w150": X_SERIES_CONTROL_TABLE,
"ax_series": AX_SERIES_CONTROL_TABLE,
"ax-12a": AX_SERIES_CONTROL_TABLE,
}
MODEL_BAUDRATE_TABLE = {
@@ -243,8 +176,6 @@ MODEL_BAUDRATE_TABLE = {
"xm430-w350": X_SERIES_BAUDRATE_TABLE,
"xm540-w270": X_SERIES_BAUDRATE_TABLE,
"xc430-w150": X_SERIES_BAUDRATE_TABLE,
"ax_series": AX_SERIES_BAUDRATE_TABLE,
"ax-12a": AX_SERIES_BAUDRATE_TABLE,
}
AVAILABLE_BAUDRATES = [
@@ -55,16 +55,10 @@ class DiffusionConfig(PreTrainedConfig):
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision
backbone. If None, no resizing is done and the original image resolution is used.
crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape
(crop_h = int(resize_shape[0] * crop_ratio), likewise for width).
Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None.
crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0,
this is computed automatically. Can also be set directly for legacy configs that use
crop-only (without resize). If None and no derivation applies, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center
crop in eval mode).
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
@@ -120,9 +114,7 @@ class DiffusionConfig(PreTrainedConfig):
# Architecture / modeling.
# Vision backbone.
vision_backbone: str = "resnet18"
resize_shape: tuple[int, int] | None = None
crop_ratio: float = 1.0
crop_shape: tuple[int, int] | None = None
crop_shape: tuple[int, int] | None = (84, 84)
crop_is_random: bool = True
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
@@ -183,25 +175,6 @@ class DiffusionConfig(PreTrainedConfig):
f"Got {self.noise_scheduler_type}."
)
if self.resize_shape is not None and (
len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape)
):
raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.")
if not (0 < self.crop_ratio <= 1.0):
raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.")
if self.resize_shape is not None:
if self.crop_ratio < 1.0:
self.crop_shape = (
int(self.resize_shape[0] * self.crop_ratio),
int(self.resize_shape[1] * self.crop_ratio),
)
else:
# Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0.
self.crop_shape = None
if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0):
raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.")
# Check that the horizon size and U-Net downsampling is compatible.
# U-Net downsamples by 2 with each stage.
downsampling_factor = 2 ** len(self.down_dims)
@@ -229,12 +202,13 @@ class DiffusionConfig(PreTrainedConfig):
if len(self.image_features) == 0 and self.env_state_feature is None:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
if self.resize_shape is None and self.crop_shape is not None:
if self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for `{key}`."
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
f"`{key}`."
)
# Check that all input images have the same shape.
@@ -454,18 +454,12 @@ class DiffusionRgbEncoder(nn.Module):
def __init__(self, config: DiffusionConfig):
super().__init__()
# Set up optional preprocessing.
if config.resize_shape is not None:
self.resize = torchvision.transforms.Resize(config.resize_shape)
else:
self.resize = None
crop_shape = config.crop_shape
if crop_shape is not None:
if config.crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
@@ -491,16 +485,13 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy shape mirrors the runtime preprocessing order: resize -> crop.
# The dummy input should take the number of image channels from `config.image_features` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.image_features`.
# Note: we have a check in the config class to make sure all images have the same shape.
images_shape = next(iter(config.image_features.values())).shape
if config.crop_shape is not None:
dummy_shape_h_w = config.crop_shape
elif config.resize_shape is not None:
dummy_shape_h_w = config.resize_shape
else:
dummy_shape_h_w = images_shape[1:]
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
@@ -516,10 +507,7 @@ class DiffusionRgbEncoder(nn.Module):
Returns:
(B, D) image feature.
"""
# Preprocess: resize if configured, then crop if configured.
if self.resize is not None:
x = self.resize(x)
# Preprocess: maybe crop (if it was set up in the __init__).
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)
+8 -1
View File
@@ -995,7 +995,14 @@ class PI0Policy(PreTrainedPolicy):
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
if _transformers_available:
from transformers.modeling_utils import no_init_weights
with no_init_weights():
model = cls(config, **kwargs)
model.model.paligemma_with_expert.paligemma.tie_weights()
else:
model = cls(config, **kwargs)
# Now manually load and remap the state dict
try:
+8 -1
View File
@@ -967,7 +967,14 @@ class PI05Policy(PreTrainedPolicy):
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
if _transformers_available:
from transformers.modeling_utils import no_init_weights
with no_init_weights():
model = cls(config, **kwargs)
model.model.paligemma_with_expert.paligemma.tie_weights()
else:
model = cls(config, **kwargs)
# Now manually load and remap the state dict
try:
@@ -895,7 +895,14 @@ class PI0FastPolicy(PreTrainedPolicy):
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
if _transformers_available:
from transformers.modeling_utils import no_init_weights
with no_init_weights():
model = cls(config, **kwargs)
model.model.paligemma_with_expert.paligemma.tie_weights()
else:
model = cls(config, **kwargs)
# Now manually load and remap the state dict
try:
@@ -106,9 +106,6 @@ class SmolVLAConfig(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
def __post_init__(self):
super().__post_init__()
@@ -593,12 +593,6 @@ class VLAFlowMatching(nn.Module):
self.prefix_length = self.config.prefix_length
self.rtc_processor = rtc_processor
# Compile model if requested
if config.compile_model:
torch.set_float32_matmul_precision("high")
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
self.forward = torch.compile(self.forward, mode=config.compile_mode)
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
+1 -1
View File
@@ -155,7 +155,7 @@ class DatasetRecordConfig:
repo_id: str
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
single_task: str
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
# Limit the frames per second.
fps: int = 30
+1 -1
View File
@@ -80,7 +80,7 @@ class DatasetReplayConfig:
repo_id: str
# Episode to replay.
episode: int
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int = 30
+2 -2
View File
@@ -380,10 +380,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
"dataloading_s": AverageMeter("data_s", ":.3f"),
}
# Keep global batch size for logging; MetricsTracker handles world size internally.
# Use effective batch size for proper epoch calculation in distributed training
effective_batch_size = cfg.batch_size * accelerator.num_processes
train_tracker = MetricsTracker(
cfg.batch_size,
effective_batch_size,
dataset.num_frames,
dataset.num_episodes,
train_metrics,
+2 -4
View File
@@ -104,10 +104,9 @@ class MetricsTracker:
self.metrics = metrics
self.steps = initial_step
world_size = accelerator.num_processes if accelerator else 1
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
self.samples = self.steps * self._batch_size * world_size
self.samples = self.steps * self._batch_size
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
self.accelerator = accelerator
@@ -133,8 +132,7 @@ class MetricsTracker:
Updates metrics that depend on 'step' for one step.
"""
self.steps += 1
world_size = self.accelerator.num_processes if self.accelerator else 1
self.samples += self._batch_size * world_size
self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54aecbc1af72a4cd5e9261492f5e7601890517516257aacdf2a0ffb3ce281f1b
oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8
size 992
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:88a9c3775a2aa1e90a08850521970070a4fcf0f6b82aab43cd8ccc5cf77e0013
oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002
size 47424
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:91a2635e05a75fe187a5081504c5f35ce3417378813fa2deaf9ca4e8200e1819
oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5
size 68
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:645bff922ac7bea63ad018ebf77c303c0e4cd2c1c0dc5ef3192865281bef3dc6
oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf
size 47424
-36
View File
@@ -24,11 +24,6 @@ def mock_metrics():
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
class MockAccelerator:
def __init__(self, num_processes: int):
self.num_processes = num_processes
def test_average_meter_initialization():
meter = AverageMeter("loss", ":.2f")
assert meter.name == "loss"
@@ -87,37 +82,6 @@ def test_metrics_tracker_step(mock_metrics):
assert tracker.epochs == tracker.samples / 1000
def test_metrics_tracker_initialization_with_accelerator(mock_metrics):
tracker = MetricsTracker(
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=mock_metrics,
initial_step=10,
accelerator=MockAccelerator(num_processes=2),
)
assert tracker.steps == 10
assert tracker.samples == 10 * 32 * 2
assert tracker.episodes == tracker.samples / (1000 / 50)
assert tracker.epochs == tracker.samples / 1000
def test_metrics_tracker_step_with_accelerator(mock_metrics):
tracker = MetricsTracker(
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=mock_metrics,
initial_step=5,
accelerator=MockAccelerator(num_processes=2),
)
tracker.step()
assert tracker.steps == 6
assert tracker.samples == (5 * 32 * 2) + (32 * 2)
assert tracker.episodes == tracker.samples / (1000 / 50)
assert tracker.epochs == tracker.samples / 1000
def test_metrics_tracker_getattr(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
assert tracker.loss == mock_metrics["loss"]