Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-15 16:49:55 +00:00)

Compare commits (4 commits):

- 8fff0fde7c
- 04de496547
- baf9b50365
- a0fdbf037a
```diff
@@ -57,7 +57,7 @@ class DatasetReplayConfig:
     repo_id: str
     # Episode to replay.
     episode: int
-    # Root directory where the dataset will be stored (e.g. 'dataset/path').
+    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
     root: str | Path | None = None
     # Limit the frames per second. By default, uses the policy fps.
     fps: int = 30
```
```diff
@@ -27,7 +27,7 @@ class DatasetConfig:
     # "dataset_index" into the returned item. The index mapping is made according to the order in which the
     # datasets are provided.
     repo_id: str
-    # Root directory where the dataset will be stored (e.g. 'dataset/path').
+    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
     root: str | None = None
     episodes: list[int] | None = None
     image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
```
```diff
@@ -664,11 +664,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         for the README).

         Args:
-            repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
-                will be stored under root/repo_id.
-            root (Path | None, optional): Local directory to use for downloading/writing files. You can also
-                set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
-                '~/.cache/huggingface/lerobot'.
+            repo_id (str): This is the repo id that will be used to fetch the dataset.
+            root (Path | None, optional): Local directory where the dataset will be downloaded and
+                stored. If set, all dataset files will be stored directly under this path. If not set, the
+                dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the
+                HF_LEROBOT_HOME environment variable).
            episodes (list[int] | None, optional): If specified, this will only load episodes specified by
                 their episode_index in this list. Defaults to None.
             image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
```
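Note: the rewritten docstring describes the new root-resolution behavior. A minimal sketch of that logic, assuming only what the docstring states (the `resolve_root` helper below is illustrative, not a function in lerobot):

```python
import os
from pathlib import Path

# Default cache root, overridable via HF_LEROBOT_HOME
# (the '~/.cache/huggingface/lerobot' fallback comes from the old docstring).
HF_LEROBOT_HOME = Path(
    os.getenv("HF_LEROBOT_HOME", "~/.cache/huggingface/lerobot")
).expanduser()


def resolve_root(repo_id: str, root: str | Path | None = None) -> Path:
    """Illustrative helper: root wins if set, else $HF_LEROBOT_HOME/repo_id."""
    return Path(root) if root is not None else HF_LEROBOT_HOME / repo_id


assert resolve_root("lerobot/pusht", root="dataset/path") == Path("dataset/path")
assert resolve_root("lerobot/pusht") == HF_LEROBOT_HOME / "lerobot/pusht"
```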
```diff
@@ -55,10 +55,16 @@ class DiffusionConfig(PreTrainedConfig):
        normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
             a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
         vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
-        crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
-            within the image size. If None, no cropping is done.
-        crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
-            mode).
+        resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision
+            backbone. If None, no resizing is done and the original image resolution is used.
+        crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape
+            (crop_h = int(resize_shape[0] * crop_ratio), likewise for width).
+            Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None.
+        crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0,
+            this is computed automatically. Can also be set directly for legacy configs that use
+            crop-only (without resize). If None and no derivation applies, no cropping is done.
+        crop_is_random: Whether the crop should be random at training time (it's always a center
+            crop in eval mode).
         pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
             `None` means no pretrained weights.
         use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
@@ -114,7 +120,9 @@ class DiffusionConfig(PreTrainedConfig):
     # Architecture / modeling.
     # Vision backbone.
     vision_backbone: str = "resnet18"
-    crop_shape: tuple[int, int] | None = (84, 84)
+    resize_shape: tuple[int, int] | None = None
+    crop_ratio: float = 1.0
+    crop_shape: tuple[int, int] | None = None
     crop_is_random: bool = True
     pretrained_backbone_weights: str | None = None
     use_group_norm: bool = True
@@ -175,6 +183,25 @@ class DiffusionConfig(PreTrainedConfig):
                 f"Got {self.noise_scheduler_type}."
             )

+        if self.resize_shape is not None and (
+            len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape)
+        ):
+            raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.")
+        if not (0 < self.crop_ratio <= 1.0):
+            raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.")
+
+        if self.resize_shape is not None:
+            if self.crop_ratio < 1.0:
+                self.crop_shape = (
+                    int(self.resize_shape[0] * self.crop_ratio),
+                    int(self.resize_shape[1] * self.crop_ratio),
+                )
+            else:
+                # Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0.
+                self.crop_shape = None
+        if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0):
+            raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.")
+
         # Check that the horizon size and U-Net downsampling is compatible.
         # U-Net downsamples by 2 with each stage.
         downsampling_factor = 2 ** len(self.down_dims)
@@ -202,13 +229,12 @@ class DiffusionConfig(PreTrainedConfig):
         if len(self.image_features) == 0 and self.env_state_feature is None:
             raise ValueError("You must provide at least one image or the environment state among the inputs.")

-        if self.crop_shape is not None:
+        if self.resize_shape is None and self.crop_shape is not None:
             for key, image_ft in self.image_features.items():
                 if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                     raise ValueError(
-                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
-                        f"for `crop_shape` and {image_ft.shape} for "
-                        f"`{key}`."
+                        f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
+                        f"for `crop_shape` and {image_ft.shape} for `{key}`."
                     )

         # Check that all input images have the same shape.
```
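Note: a worked example of the `crop_shape` derivation added in `__post_init__` above, with illustrative values chosen so the derived crop matches the old `(84, 84)` default:

```python
# Illustrative values: resize to 96x96, crop_ratio 0.875 -> 84x84 crop.
resize_shape = (96, 96)
crop_ratio = 0.875

crop_shape = (
    int(resize_shape[0] * crop_ratio),
    int(resize_shape[1] * crop_ratio),
)
assert crop_shape == (84, 84)  # same size as the previous default

# With crop_ratio == 1.0 the new code sets crop_shape = None instead,
# i.e. resizing without any cropping.
```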
```diff
@@ -454,12 +454,18 @@ class DiffusionRgbEncoder(nn.Module):
     def __init__(self, config: DiffusionConfig):
         super().__init__()
         # Set up optional preprocessing.
-        if config.crop_shape is not None:
+        if config.resize_shape is not None:
+            self.resize = torchvision.transforms.Resize(config.resize_shape)
+        else:
+            self.resize = None
+
+        crop_shape = config.crop_shape
+        if crop_shape is not None:
             self.do_crop = True
             # Always use center crop for eval
-            self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
+            self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
             if config.crop_is_random:
-                self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
+                self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
             else:
                 self.maybe_random_crop = self.center_crop
         else:
@@ -485,13 +491,16 @@ class DiffusionRgbEncoder(nn.Module):

         # Set up pooling and final layers.
         # Use a dry run to get the feature map shape.
-        # The dummy input should take the number of image channels from `config.image_features` and it should
-        # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
-        # height and width from `config.image_features`.
+        # The dummy shape mirrors the runtime preprocessing order: resize -> crop.
         # Note: we have a check in the config class to make sure all images have the same shape.
         images_shape = next(iter(config.image_features.values())).shape
-        dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
+        if config.crop_shape is not None:
+            dummy_shape_h_w = config.crop_shape
+        elif config.resize_shape is not None:
+            dummy_shape_h_w = config.resize_shape
+        else:
+            dummy_shape_h_w = images_shape[1:]
         dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
         feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]

@@ -507,7 +516,10 @@ class DiffusionRgbEncoder(nn.Module):
         Returns:
             (B, D) image feature.
         """
-        # Preprocess: maybe crop (if it was set up in the __init__).
+        # Preprocess: resize if configured, then crop if configured.
+        if self.resize is not None:
+            x = self.resize(x)
         if self.do_crop:
             if self.training:  # noqa: SIM108
                 x = self.maybe_random_crop(x)
```
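Note: the preprocessing order the encoder now applies is resize first, then crop. A standalone sketch with illustrative shapes (torchvision transforms as in the diff):

```python
import torch
import torchvision

# Illustrative shapes: a 240x320 frame resized to 96x96, then center-cropped
# to 84x84, mirroring the resize -> crop order in DiffusionRgbEncoder.forward.
resize = torchvision.transforms.Resize((96, 96))
center_crop = torchvision.transforms.CenterCrop((84, 84))

x = torch.rand(1, 3, 240, 320)  # (B, C, H, W)
x = resize(x)        # -> (1, 3, 96, 96)
x = center_crop(x)   # -> (1, 3, 84, 84)
assert x.shape == (1, 3, 84, 84)
```

At train time the random crop is used instead of the center crop when `crop_is_random` is set; eval always uses the center crop, as before.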
```diff
@@ -995,14 +995,7 @@ class PI0Policy(PreTrainedPolicy):

         # Initialize model without loading weights
         # Check if dataset_stats were provided in kwargs
-        if _transformers_available:
-            from transformers.modeling_utils import no_init_weights
-
-            with no_init_weights():
-                model = cls(config, **kwargs)
-            model.model.paligemma_with_expert.paligemma.tie_weights()
-        else:
-            model = cls(config, **kwargs)
+        model = cls(config, **kwargs)

         # Now manually load and remap the state dict
         try:
```
```diff
@@ -967,14 +967,7 @@ class PI05Policy(PreTrainedPolicy):

         # Initialize model without loading weights
         # Check if dataset_stats were provided in kwargs
-        if _transformers_available:
-            from transformers.modeling_utils import no_init_weights
-
-            with no_init_weights():
-                model = cls(config, **kwargs)
-            model.model.paligemma_with_expert.paligemma.tie_weights()
-        else:
-            model = cls(config, **kwargs)
+        model = cls(config, **kwargs)

         # Now manually load and remap the state dict
         try:
```
```diff
@@ -895,14 +895,7 @@ class PI0FastPolicy(PreTrainedPolicy):

         # Initialize model without loading weights
         # Check if dataset_stats were provided in kwargs
-        if _transformers_available:
-            from transformers.modeling_utils import no_init_weights
-
-            with no_init_weights():
-                model = cls(config, **kwargs)
-            model.model.paligemma_with_expert.paligemma.tie_weights()
-        else:
-            model = cls(config, **kwargs)
+        model = cls(config, **kwargs)

         # Now manually load and remap the state dict
         try:
```
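Note: the three identical hunks above drop the `no_init_weights` fast path and always construct the model normally before loading the state dict. For reference, a minimal sketch of the removed pattern, using a small transformers model (`bert-base-uncased` is only an example, not what these policies load):

```python
# Sketch of the removed pattern, not lerobot code. no_init_weights() skips
# random weight initialization, so the constructed model is only valid once
# a state dict is loaded over it; every parameter must then be covered.
from transformers import AutoConfig, AutoModel
from transformers.modeling_utils import no_init_weights

config = AutoConfig.from_pretrained("bert-base-uncased")
with no_init_weights():
    model = AutoModel.from_config(config)  # params allocated, left uninitialized
model.tie_weights()  # re-tie shared params, as the removed code did for PaliGemma
```

Constructing with `cls(config, **kwargs)` unconditionally is slower but leaves no parameter uninitialized if the remapped state dict misses one.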
```diff
@@ -106,6 +106,9 @@ class SmolVLAConfig(PreTrainedConfig):
     # Real-Time Chunking (RTC) configuration
     rtc_config: RTCConfig | None = None

+    compile_model: bool = False  # Whether to use torch.compile for model optimization
+    compile_mode: str = "max-autotune"  # Torch compile mode
+
     def __post_init__(self):
         super().__post_init__()

```
```diff
@@ -593,6 +593,12 @@ class VLAFlowMatching(nn.Module):
         self.prefix_length = self.config.prefix_length
         self.rtc_processor = rtc_processor

+        # Compile model if requested
+        if config.compile_model:
+            torch.set_float32_matmul_precision("high")
+            self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
+            self.forward = torch.compile(self.forward, mode=config.compile_mode)
+
     def _rtc_enabled(self):
         return self.config.rtc_config is not None and self.config.rtc_config.enabled

```
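Note: the compile hook wraps bound methods with `torch.compile` at construction time. The same pattern in isolation, with an illustrative module (`mode="max-autotune"` matches the new `compile_mode` default):

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Illustrative module using the same compile-in-__init__ pattern."""

    def __init__(self, compile_model: bool = True, compile_mode: str = "max-autotune"):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        if compile_model:
            # Allow TF32 matmuls, then compile the hot method, as in the diff.
            torch.set_float32_matmul_precision("high")
            self.forward = torch.compile(self.forward, mode=compile_mode)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x).relu()


net = TinyNet()
out = net(torch.randn(2, 8))  # first call triggers compilation
```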
```diff
@@ -155,7 +155,7 @@ class DatasetRecordConfig:
     repo_id: str
     # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
     single_task: str
-    # Root directory where the dataset will be stored (e.g. 'dataset/path').
+    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
     root: str | Path | None = None
     # Limit the frames per second.
     fps: int = 30
```
```diff
@@ -80,7 +80,7 @@ class DatasetReplayConfig:
     repo_id: str
     # Episode to replay.
     episode: int
-    # Root directory where the dataset will be stored (e.g. 'dataset/path').
+    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
     root: str | Path | None = None
     # Limit the frames per second. By default, uses the policy fps.
     fps: int = 30
```
```diff
@@ -380,10 +380,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
         "dataloading_s": AverageMeter("data_s", ":.3f"),
     }

-    # Use effective batch size for proper epoch calculation in distributed training
+    # Keep global batch size for logging; MetricsTracker handles world size internally.
     effective_batch_size = cfg.batch_size * accelerator.num_processes
     train_tracker = MetricsTracker(
-        effective_batch_size,
+        cfg.batch_size,
         dataset.num_frames,
         dataset.num_episodes,
         train_metrics,
```
```diff
@@ -104,9 +104,10 @@ class MetricsTracker:
         self.metrics = metrics

         self.steps = initial_step
+        world_size = accelerator.num_processes if accelerator else 1
         # A sample is an (observation,action) pair, where observation and action
         # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
-        self.samples = self.steps * self._batch_size
+        self.samples = self.steps * self._batch_size * world_size
         self.episodes = self.samples / self._avg_samples_per_ep
         self.epochs = self.samples / self._num_frames
         self.accelerator = accelerator
@@ -132,7 +133,8 @@ class MetricsTracker:
         Updates metrics that depend on 'step' for one step.
         """
         self.steps += 1
-        self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
+        world_size = self.accelerator.num_processes if self.accelerator else 1
+        self.samples += self._batch_size * world_size
         self.episodes = self.samples / self._avg_samples_per_ep
         self.epochs = self.samples / self._num_frames

```
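Note: a quick arithmetic check of the world-size-aware accounting above, using the same numbers as the new tests below: with `batch_size=32`, two processes, and `initial_step=5`, the tracker resumes at `5 * 32 * 2 = 320` samples, and one `step()` brings it to `384`, exactly what `test_metrics_tracker_step_with_accelerator` asserts.

```python
# Standalone check of the sample accounting (same numbers as the tests below).
batch_size, world_size, initial_step = 32, 2, 5
num_frames, num_episodes = 1000, 50
avg_samples_per_ep = num_frames / num_episodes  # 20.0

samples = initial_step * batch_size * world_size  # 320 on resume
samples += batch_size * world_size                # 384 after one step()
assert samples == (5 * 32 * 2) + (32 * 2)
assert samples / avg_samples_per_ep == 19.2  # episodes
assert samples / num_frames == 0.384         # epochs
```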
```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8
+oid sha256:54aecbc1af72a4cd5e9261492f5e7601890517516257aacdf2a0ffb3ce281f1b
 size 992
```

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002
+oid sha256:88a9c3775a2aa1e90a08850521970070a4fcf0f6b82aab43cd8ccc5cf77e0013
 size 47424
```

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5
+oid sha256:91a2635e05a75fe187a5081504c5f35ce3417378813fa2deaf9ca4e8200e1819
 size 68
```

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf
+oid sha256:645bff922ac7bea63ad018ebf77c303c0e4cd2c1c0dc5ef3192865281bef3dc6
 size 47424
```
```diff
@@ -24,6 +24,11 @@ def mock_metrics():
     return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}


+class MockAccelerator:
+    def __init__(self, num_processes: int):
+        self.num_processes = num_processes
+
+
 def test_average_meter_initialization():
     meter = AverageMeter("loss", ":.2f")
     assert meter.name == "loss"
@@ -82,6 +87,37 @@ def test_metrics_tracker_step(mock_metrics):
     assert tracker.epochs == tracker.samples / 1000


+def test_metrics_tracker_initialization_with_accelerator(mock_metrics):
+    tracker = MetricsTracker(
+        batch_size=32,
+        num_frames=1000,
+        num_episodes=50,
+        metrics=mock_metrics,
+        initial_step=10,
+        accelerator=MockAccelerator(num_processes=2),
+    )
+    assert tracker.steps == 10
+    assert tracker.samples == 10 * 32 * 2
+    assert tracker.episodes == tracker.samples / (1000 / 50)
+    assert tracker.epochs == tracker.samples / 1000
+
+
+def test_metrics_tracker_step_with_accelerator(mock_metrics):
+    tracker = MetricsTracker(
+        batch_size=32,
+        num_frames=1000,
+        num_episodes=50,
+        metrics=mock_metrics,
+        initial_step=5,
+        accelerator=MockAccelerator(num_processes=2),
+    )
+    tracker.step()
+    assert tracker.steps == 6
+    assert tracker.samples == (5 * 32 * 2) + (32 * 2)
+    assert tracker.episodes == tracker.samples / (1000 / 50)
+    assert tracker.epochs == tracker.samples / 1000
+
+
 def test_metrics_tracker_getattr(mock_metrics):
     tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
     assert tracker.loss == mock_metrics["loss"]
```