mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
bug remove
This commit is contained in:
@@ -46,7 +46,6 @@ Note that in both examples, the repo/folder should contain at least `config.json
|
|||||||
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
|
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import concurrent
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
@@ -57,7 +56,7 @@ from copy import deepcopy
|
|||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
import concurrent
|
||||||
import einops
|
import einops
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -157,29 +156,9 @@ def rollout(
|
|||||||
# Infer "task" from attributes of environments.
|
# Infer "task" from attributes of environments.
|
||||||
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
||||||
observation = add_envs_task(env, observation)
|
observation = add_envs_task(env, observation)
|
||||||
if step % 100 == 0:
|
|
||||||
import imageio.v2 as imageio
|
|
||||||
|
|
||||||
img = observation["observation.images.image"] # (1, 3, 256, 256)
|
|
||||||
|
|
||||||
if isinstance(img, torch.Tensor):
|
|
||||||
img = img.detach().cpu().numpy()
|
|
||||||
|
|
||||||
# remove batch → (3, 256, 256)
|
|
||||||
img = img[0]
|
|
||||||
|
|
||||||
# transpose → (256, 256, 3)
|
|
||||||
img = np.transpose(img, (1, 2, 0))
|
|
||||||
|
|
||||||
# scale + convert to uint8
|
|
||||||
img = (img * 255).clip(0, 255).astype(np.uint8)
|
|
||||||
|
|
||||||
# now works
|
|
||||||
imageio.imwrite(f"obs_{step:06d}.png", img)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
action = policy.select_action(observation)
|
action = policy.select_action(observation)
|
||||||
observation["observation.images.image"]
|
observation['observation.images.image']
|
||||||
# Convert to CPU / numpy.
|
# Convert to CPU / numpy.
|
||||||
action = action.to("cpu").numpy()
|
action = action.to("cpu").numpy()
|
||||||
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||||
@@ -198,11 +177,12 @@ def rollout(
|
|||||||
|
|
||||||
# Keep track of which environments are done so far.
|
# Keep track of which environments are done so far.
|
||||||
# done = terminated | truncated | done
|
# done = terminated | truncated | done
|
||||||
# TODO: jadechoghari changed, this is cleaner
|
#TODO: jadechoghari changed, this is cleaner
|
||||||
done = terminated | truncated | done
|
done = terminated | truncated | done
|
||||||
if step + 1 == max_steps:
|
if step + 1 == max_steps:
|
||||||
done = np.ones_like(done, dtype=bool)
|
done = np.ones_like(done, dtype=bool)
|
||||||
|
|
||||||
|
|
||||||
all_actions.append(torch.from_numpy(action))
|
all_actions.append(torch.from_numpy(action))
|
||||||
all_rewards.append(torch.from_numpy(reward))
|
all_rewards.append(torch.from_numpy(reward))
|
||||||
all_dones.append(torch.from_numpy(done))
|
all_dones.append(torch.from_numpy(done))
|
||||||
@@ -398,7 +378,7 @@ def eval_policy(
|
|||||||
# Wait till all video rendering threads are done.
|
# Wait till all video rendering threads are done.
|
||||||
for thread in threads:
|
for thread in threads:
|
||||||
thread.join()
|
thread.join()
|
||||||
|
|
||||||
# Compile eval info.
|
# Compile eval info.
|
||||||
info = {
|
info = {
|
||||||
"per_episode": [
|
"per_episode": [
|
||||||
@@ -480,22 +460,16 @@ def _compile_episode_data(
|
|||||||
|
|
||||||
return data_dict
|
return data_dict
|
||||||
|
|
||||||
|
|
||||||
def set_global_seed(seed):
|
def set_global_seed(seed):
|
||||||
"""Set seed for reproducibility."""
|
"""Set seed for reproducibility."""
|
||||||
import random
|
import random
|
||||||
|
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
def log_output_dir(out_dir):
|
def log_output_dir(out_dir):
|
||||||
logging.info("Output dir:" + f" {out_dir}")
|
logging.info("Output dir:"+ f" {out_dir}")
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def eval(cfg: EvalPipelineConfig):
|
def eval(cfg: EvalPipelineConfig):
|
||||||
logging.info(pformat(asdict(cfg)))
|
logging.info(pformat(asdict(cfg)))
|
||||||
|
|||||||
Reference in New Issue
Block a user