mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
+18
-12
@@ -46,6 +46,7 @@ Note that in both examples, the repo/folder should contain at least `config.json
|
||||
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
|
||||
"""
|
||||
|
||||
import concurrent
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
@@ -56,7 +57,7 @@ from copy import deepcopy
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
import concurrent
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
@@ -158,27 +159,27 @@ def rollout(
|
||||
observation = add_envs_task(env, observation)
|
||||
if step % 100 == 0:
|
||||
import imageio.v2 as imageio
|
||||
|
||||
|
||||
img = observation["observation.images.image"] # (1, 3, 256, 256)
|
||||
|
||||
|
||||
if isinstance(img, torch.Tensor):
|
||||
img = img.detach().cpu().numpy()
|
||||
|
||||
|
||||
# remove batch → (3, 256, 256)
|
||||
img = img[0]
|
||||
|
||||
|
||||
# transpose → (256, 256, 3)
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
|
||||
|
||||
# scale + convert to uint8
|
||||
img = (img * 255).clip(0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
# now works
|
||||
imageio.imwrite(f"obs_{step:06d}.png", img)
|
||||
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
observation['observation.images.image']
|
||||
observation["observation.images.image"]
|
||||
# Convert to CPU / numpy.
|
||||
action = action.to("cpu").numpy()
|
||||
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||
@@ -197,12 +198,11 @@ def rollout(
|
||||
|
||||
# Keep track of which environments are done so far.
|
||||
# done = terminated | truncated | done
|
||||
#TODO: jadechoghari changed, this is cleaner
|
||||
# TODO: jadechoghari changed, this is cleaner
|
||||
done = terminated | truncated | done
|
||||
if step + 1 == max_steps:
|
||||
done = np.ones_like(done, dtype=bool)
|
||||
|
||||
|
||||
all_actions.append(torch.from_numpy(action))
|
||||
all_rewards.append(torch.from_numpy(reward))
|
||||
all_dones.append(torch.from_numpy(done))
|
||||
@@ -398,7 +398,7 @@ def eval_policy(
|
||||
# Wait till all video rendering threads are done.
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
|
||||
# Compile eval info.
|
||||
info = {
|
||||
"per_episode": [
|
||||
@@ -480,16 +480,22 @@ def _compile_episode_data(
|
||||
|
||||
return data_dict
|
||||
|
||||
|
||||
def set_global_seed(seed):
|
||||
"""Set seed for reproducibility."""
|
||||
import random
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def log_output_dir(out_dir):
|
||||
logging.info("Output dir:"+ f" {out_dir}")
|
||||
logging.info("Output dir:" + f" {out_dir}")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def eval(cfg: EvalPipelineConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
Reference in New Issue
Block a user