mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
refactor(pipeline): Remove model card generation and streamline processor methods
- Eliminated the _generate_model_card method from RobotProcessor, which was responsible for generating README.md files from a template. - Updated save_pretrained method to remove model card generation, focusing on serialization of processor definitions and parameters. - Added default implementations for get_config, state_dict, load_state_dict, reset, and feature_contract methods in various processor classes to enhance consistency and usability.
This commit is contained in:
@@ -13,8 +13,6 @@
|
|||||||
title: Cameras
|
title: Cameras
|
||||||
- local: integrate_hardware
|
- local: integrate_hardware
|
||||||
title: Bring Your Own Hardware
|
title: Bring Your Own Hardware
|
||||||
- local: processor_tutorial
|
|
||||||
title: RobotProcessor Pipeline
|
|
||||||
- local: hilserl
|
- local: hilserl
|
||||||
title: Train a Robot with RL
|
title: Train a Robot with RL
|
||||||
- local: hilserl_sim
|
- local: hilserl_sim
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -409,23 +409,6 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
config_filename = kwargs.pop("config_filename", None)
|
config_filename = kwargs.pop("config_filename", None)
|
||||||
self.save_pretrained(destination_path, config_filename=config_filename)
|
self.save_pretrained(destination_path, config_filename=config_filename)
|
||||||
|
|
||||||
def _generate_model_card(self, destination_path: str) -> None:
|
|
||||||
"""Generate README.md from the RobotProcessor model card template."""
|
|
||||||
# Read the template
|
|
||||||
template_path = Path(__file__).parent.parent / "templates" / "robotprocessor_modelcard_template.md"
|
|
||||||
|
|
||||||
if not template_path.exists():
|
|
||||||
# Fallback: if template doesn't exist, skip model card generation
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(template_path) as f:
|
|
||||||
model_card_content = f.read()
|
|
||||||
|
|
||||||
# Write the README.md
|
|
||||||
readme_path = os.path.join(destination_path, "README.md")
|
|
||||||
with open(readme_path, "w") as f:
|
|
||||||
f.write(model_card_content)
|
|
||||||
|
|
||||||
def save_pretrained(self, destination_path: str, config_filename: str | None = None, **kwargs):
|
def save_pretrained(self, destination_path: str, config_filename: str | None = None, **kwargs):
|
||||||
"""Serialize the processor definition and parameters to *destination_path*.
|
"""Serialize the processor definition and parameters to *destination_path*.
|
||||||
|
|
||||||
@@ -500,9 +483,6 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
with open(os.path.join(destination_path, config_filename), "w") as file_pointer:
|
with open(os.path.join(destination_path, config_filename), "w") as file_pointer:
|
||||||
json.dump(config, file_pointer, indent=2)
|
json.dump(config, file_pointer, indent=2)
|
||||||
|
|
||||||
# Generate README.md from template
|
|
||||||
self._generate_model_card(destination_path)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls, source: str, *, config_filename: str | None = None, overrides: dict[str, Any] | None = None
|
cls, source: str, *, config_filename: str | None = None, overrides: dict[str, Any] | None = None
|
||||||
@@ -910,6 +890,21 @@ class ObservationProcessor:
|
|||||||
new_transition[TransitionKey.OBSERVATION] = processed_observation
|
new_transition[TransitionKey.OBSERVATION] = processed_observation
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
class ActionProcessor:
|
class ActionProcessor:
|
||||||
"""Base class for processors that modify only the action component of a transition.
|
"""Base class for processors that modify only the action component of a transition.
|
||||||
@@ -952,6 +947,21 @@ class ActionProcessor:
|
|||||||
new_transition[TransitionKey.ACTION] = processed_action
|
new_transition[TransitionKey.ACTION] = processed_action
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
class RewardProcessor:
|
class RewardProcessor:
|
||||||
"""Base class for processors that modify only the reward component of a transition.
|
"""Base class for processors that modify only the reward component of a transition.
|
||||||
@@ -993,6 +1003,21 @@ class RewardProcessor:
|
|||||||
new_transition[TransitionKey.REWARD] = processed_reward
|
new_transition[TransitionKey.REWARD] = processed_reward
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
class DoneProcessor:
|
class DoneProcessor:
|
||||||
"""Base class for processors that modify only the done flag of a transition.
|
"""Base class for processors that modify only the done flag of a transition.
|
||||||
@@ -1039,6 +1064,21 @@ class DoneProcessor:
|
|||||||
new_transition[TransitionKey.DONE] = processed_done
|
new_transition[TransitionKey.DONE] = processed_done
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
class TruncatedProcessor:
|
class TruncatedProcessor:
|
||||||
"""Base class for processors that modify only the truncated flag of a transition.
|
"""Base class for processors that modify only the truncated flag of a transition.
|
||||||
@@ -1081,6 +1121,21 @@ class TruncatedProcessor:
|
|||||||
new_transition[TransitionKey.TRUNCATED] = processed_truncated
|
new_transition[TransitionKey.TRUNCATED] = processed_truncated
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
class InfoProcessor:
|
class InfoProcessor:
|
||||||
"""Base class for processors that modify only the info dictionary of a transition.
|
"""Base class for processors that modify only the info dictionary of a transition.
|
||||||
@@ -1128,6 +1183,21 @@ class InfoProcessor:
|
|||||||
new_transition[TransitionKey.INFO] = processed_info
|
new_transition[TransitionKey.INFO] = processed_info
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
class ComplementaryDataProcessor:
|
class ComplementaryDataProcessor:
|
||||||
"""Base class for processors that modify only the complementary data of a transition.
|
"""Base class for processors that modify only the complementary data of a transition.
|
||||||
@@ -1156,6 +1226,21 @@ class ComplementaryDataProcessor:
|
|||||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data
|
new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
class IdentityProcessor:
|
class IdentityProcessor:
|
||||||
"""Identity processor that does nothing."""
|
"""Identity processor that does nothing."""
|
||||||
|
|||||||
@@ -1,195 +0,0 @@
|
|||||||
---
|
|
||||||
library_name: lerobot
|
|
||||||
tags:
|
|
||||||
- robotics
|
|
||||||
- lerobot
|
|
||||||
- safetensors
|
|
||||||
pipeline_tag: robotics
|
|
||||||
---
|
|
||||||
|
|
||||||
# RobotProcessor
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
RobotProcessor is a composable, debuggable post-processing pipeline for robot transitions in the LeRobot framework. It orchestrates an ordered collection of small, functional transforms (steps) that are executed left-to-right on each incoming `EnvTransition`.
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
The RobotProcessor provides a modular architecture for processing robot environment transitions through a sequence of composable steps. Each step is a callable that accepts a full `EnvTransition` tuple and returns a potentially modified tuple of the same structure.
|
|
||||||
|
|
||||||
### EnvTransition Structure
|
|
||||||
|
|
||||||
An `EnvTransition` is a 7-tuple containing:
|
|
||||||
|
|
||||||
1. **observation**: Current state observation
|
|
||||||
2. **action**: Action taken (can be None)
|
|
||||||
3. **reward**: Reward received (float or None)
|
|
||||||
4. **done**: Episode termination flag (bool or None)
|
|
||||||
5. **truncated**: Episode truncation flag (bool or None)
|
|
||||||
6. **info**: Additional information dictionary
|
|
||||||
7. **complementary_data**: Extra data dictionary
|
|
||||||
|
|
||||||
## Key Features
|
|
||||||
|
|
||||||
- **Composable Pipeline**: Chain multiple processing steps in a specific order
|
|
||||||
- **State Persistence**: Save and load processor state using SafeTensors format
|
|
||||||
- **Hugging Face Hub Integration**: Easy sharing and loading via `save_pretrained()` and `from_pretrained()`
|
|
||||||
- **Debugging Support**: Step-through functionality to inspect intermediate transformations
|
|
||||||
- **Hook System**: Before/after step hooks for additional processing or monitoring
|
|
||||||
- **Device Support**: Move tensor states to different devices (CPU/GPU)
|
|
||||||
- **Performance Profiling**: Built-in profiling to identify bottlenecks
|
|
||||||
|
|
||||||
## Installation
|
|
||||||
|
|
||||||
Follow the [installation instructions](https://huggingface.co/docs/lerobot/installation) to install the package.
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
### Basic Example
|
|
||||||
|
|
||||||
```python
|
|
||||||
from lerobot.processor.pipeline import RobotProcessor
|
|
||||||
from your_steps import ObservationNormalizer, VelocityCalculator
|
|
||||||
|
|
||||||
# Create a processor with multiple steps
|
|
||||||
processor = RobotProcessor(
|
|
||||||
steps=[
|
|
||||||
ObservationNormalizer(mean=0, std=1),
|
|
||||||
VelocityCalculator(window_size=5),
|
|
||||||
],
|
|
||||||
name="my_robot_processor",
|
|
||||||
seed=42
|
|
||||||
)
|
|
||||||
|
|
||||||
# Process a transition
|
|
||||||
obs, info = env.reset()
|
|
||||||
transition = (obs, None, 0.0, False, False, info, {})
|
|
||||||
processed_transition = processor(transition)
|
|
||||||
|
|
||||||
# Extract processed observation
|
|
||||||
processed_obs = processed_transition[0]
|
|
||||||
```
|
|
||||||
|
|
||||||
### Saving and Loading
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Save locally
|
|
||||||
processor.save_pretrained("./my_processor")
|
|
||||||
|
|
||||||
# Push to Hugging Face Hub
|
|
||||||
processor.push_to_hub("username/my-robot-processor")
|
|
||||||
|
|
||||||
# Load from Hub
|
|
||||||
loaded_processor = RobotProcessor.from_pretrained("username/my-robot-processor")
|
|
||||||
```
|
|
||||||
|
|
||||||
### Debugging with Step-Through
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Inspect intermediate results
|
|
||||||
for idx, intermediate_transition in enumerate(processor.step_through(transition)):
|
|
||||||
print(f"After step {idx}: {intermediate_transition[0]}") # Print observation
|
|
||||||
```
|
|
||||||
|
|
||||||
### Using Hooks
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Add monitoring hook
|
|
||||||
def log_observation(step_idx, transition):
|
|
||||||
print(f"Step {step_idx}: obs shape = {transition[0].shape}")
|
|
||||||
return None # Don't modify transition
|
|
||||||
|
|
||||||
processor.register_before_step_hook(log_observation)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Creating Custom Steps
|
|
||||||
|
|
||||||
To create a custom processor step, implement the `ProcessorStep` protocol:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from lerobot.processor.pipeline import ProcessorStepRegistry, EnvTransition
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register("my_custom_step")
|
|
||||||
class MyCustomStep:
|
|
||||||
def __init__(self, param1=1.0):
|
|
||||||
self.param1 = param1
|
|
||||||
self.buffer = []
|
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
||||||
obs, action, reward, done, truncated, info, comp_data = transition
|
|
||||||
# Process observation
|
|
||||||
processed_obs = obs * self.param1
|
|
||||||
return (processed_obs, action, reward, done, truncated, info, comp_data)
|
|
||||||
|
|
||||||
def get_config(self) -> dict:
|
|
||||||
return {"param1": self.param1}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict:
|
|
||||||
# Return only torch.Tensor state
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def load_state_dict(self, state: dict) -> None:
|
|
||||||
# Load tensor state
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
# Clear buffers at episode boundaries
|
|
||||||
self.buffer.clear()
|
|
||||||
```
|
|
||||||
|
|
||||||
## Advanced Features
|
|
||||||
|
|
||||||
### Device Management
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Move all tensor states to GPU
|
|
||||||
processor = processor.to("cuda")
|
|
||||||
|
|
||||||
# Move to specific device
|
|
||||||
processor = processor.to(torch.device("cuda:1"))
|
|
||||||
```
|
|
||||||
|
|
||||||
### Performance Profiling
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Profile step execution times
|
|
||||||
profile_results = processor.profile_steps(transition, num_runs=100)
|
|
||||||
for step_name, time_ms in profile_results.items():
|
|
||||||
print(f"{step_name}: {time_ms:.3f} ms")
|
|
||||||
```
|
|
||||||
|
|
||||||
### Processor Slicing
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Get a single step
|
|
||||||
first_step = processor[0]
|
|
||||||
|
|
||||||
# Create a sub-processor with steps 1-3
|
|
||||||
sub_processor = processor[1:4]
|
|
||||||
```
|
|
||||||
|
|
||||||
## Model Card Specifications
|
|
||||||
|
|
||||||
- **Pipeline Tag**: robotics
|
|
||||||
- **Library**: lerobot
|
|
||||||
- **Format**: safetensors
|
|
||||||
- **License**: Apache 2.0
|
|
||||||
|
|
||||||
## Limitations
|
|
||||||
|
|
||||||
- Steps must maintain the 7-tuple structure of EnvTransition
|
|
||||||
- All tensor state must be separated from configuration for proper serialization
|
|
||||||
- Steps are executed sequentially (no parallel processing within a single transition)
|
|
||||||
|
|
||||||
## Citation
|
|
||||||
|
|
||||||
If you use RobotProcessor in your research, please cite:
|
|
||||||
|
|
||||||
```bibtex
|
|
||||||
@misc{cadene2024lerobot,
|
|
||||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
|
|
||||||
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
|
||||||
howpublished = "\url{https://github.com/huggingface/lerobot}",
|
|
||||||
year = {2024}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
@@ -22,7 +22,7 @@ from gymnasium.utils.env_checker import check_env
|
|||||||
|
|
||||||
import lerobot
|
import lerobot
|
||||||
from lerobot.envs.factory import make_env, make_env_config
|
from lerobot.envs.factory import make_env, make_env_config
|
||||||
from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor
|
from lerobot.envs.utils import preprocess_observation
|
||||||
from tests.utils import require_env
|
from tests.utils import require_env
|
||||||
|
|
||||||
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
||||||
@@ -48,12 +48,7 @@ def test_factory(env_name):
|
|||||||
cfg = make_env_config(env_name)
|
cfg = make_env_config(env_name)
|
||||||
env = make_env(cfg, n_envs=1)
|
env = make_env(cfg, n_envs=1)
|
||||||
obs, _ = env.reset()
|
obs, _ = env.reset()
|
||||||
|
obs = preprocess_observation(obs)
|
||||||
# Process observation using processor
|
|
||||||
obs_processor = RobotProcessor([VanillaObservationProcessor()])
|
|
||||||
transition = (obs, None, None, None, None, None, None)
|
|
||||||
processed_transition = obs_processor(transition)
|
|
||||||
obs = processed_transition[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# test image keys are float32 in range [0,1]
|
# test image keys are float32 in range [0,1]
|
||||||
for key in obs:
|
for key in obs:
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
|||||||
from lerobot.datasets.factory import make_dataset
|
from lerobot.datasets.factory import make_dataset
|
||||||
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
||||||
from lerobot.envs.factory import make_env, make_env_config
|
from lerobot.envs.factory import make_env, make_env_config
|
||||||
|
from lerobot.envs.utils import preprocess_observation
|
||||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||||
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
|
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
|
||||||
from lerobot.policies.factory import (
|
from lerobot.policies.factory import (
|
||||||
@@ -39,7 +40,6 @@ from lerobot.policies.factory import (
|
|||||||
)
|
)
|
||||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor
|
|
||||||
from lerobot.utils.random_utils import seeded_context
|
from lerobot.utils.random_utils import seeded_context
|
||||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||||
@@ -185,10 +185,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
|||||||
observation, _ = env.reset(seed=train_cfg.seed)
|
observation, _ = env.reset(seed=train_cfg.seed)
|
||||||
|
|
||||||
# apply transform to normalize the observations
|
# apply transform to normalize the observations
|
||||||
obs_processor = RobotProcessor([VanillaObservationProcessor()])
|
observation = preprocess_observation(observation)
|
||||||
transition = (observation, None, None, None, None, None, None)
|
|
||||||
processed_transition = obs_processor(transition)
|
|
||||||
observation = processed_transition[TransitionKey.OBSERVATION]
|
|
||||||
|
|
||||||
# send observation to device/gpu
|
# send observation to device/gpu
|
||||||
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
|
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
|
||||||
|
|||||||
Reference in New Issue
Block a user