mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 14:09:47 +00:00
Replay while loop in sample actions with for loops (#2600)
This commit is contained in:
@@ -812,16 +812,13 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
)
|
)
|
||||||
|
|
||||||
dt = -1.0 / num_steps
|
dt = -1.0 / num_steps
|
||||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
|
||||||
|
|
||||||
x_t = noise
|
x_t = noise
|
||||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
for step in range(num_steps):
|
||||||
while time >= -dt / 2:
|
time = 1.0 + step * dt
|
||||||
expanded_time = time.expand(bsize)
|
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||||
|
|
||||||
# Define a closure function to properly capture expanded_time
|
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
|
||||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
|
||||||
return self.denoise_step(
|
return self.denoise_step(
|
||||||
state=state,
|
state=state,
|
||||||
prefix_pad_masks=prefix_pad_masks,
|
prefix_pad_masks=prefix_pad_masks,
|
||||||
@@ -846,15 +843,11 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
else:
|
else:
|
||||||
v_t = denoise_step_partial_call(x_t)
|
v_t = denoise_step_partial_call(x_t)
|
||||||
|
|
||||||
# Euler step
|
x_t = x_t + dt * v_t
|
||||||
x_t += dt * v_t
|
|
||||||
|
|
||||||
# Record x_t and v_t after Euler step
|
|
||||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||||
|
|
||||||
time += dt
|
|
||||||
|
|
||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
def denoise_step(
|
def denoise_step(
|
||||||
|
|||||||
@@ -787,16 +787,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
)
|
)
|
||||||
|
|
||||||
dt = -1.0 / num_steps
|
dt = -1.0 / num_steps
|
||||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
|
||||||
|
|
||||||
x_t = noise
|
x_t = noise
|
||||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
for step in range(num_steps):
|
||||||
while time >= -dt / 2:
|
time = 1.0 + step * dt
|
||||||
expanded_time = time.expand(bsize)
|
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||||
|
|
||||||
# Define a closure function to properly capture expanded_time
|
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
|
||||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
|
||||||
return self.denoise_step(
|
return self.denoise_step(
|
||||||
prefix_pad_masks=prefix_pad_masks,
|
prefix_pad_masks=prefix_pad_masks,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
@@ -820,15 +817,11 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
else:
|
else:
|
||||||
v_t = denoise_step_partial_call(x_t)
|
v_t = denoise_step_partial_call(x_t)
|
||||||
|
|
||||||
# Euler step
|
x_t = x_t + dt * v_t
|
||||||
x_t += dt * v_t
|
|
||||||
|
|
||||||
# Record x_t and v_t after Euler step
|
|
||||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||||
|
|
||||||
time += dt
|
|
||||||
|
|
||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
def denoise_step(
|
def denoise_step(
|
||||||
|
|||||||
@@ -783,18 +783,15 @@ class VLAFlowMatching(nn.Module):
|
|||||||
use_cache=self.config.use_cache,
|
use_cache=self.config.use_cache,
|
||||||
fill_kv_cache=True,
|
fill_kv_cache=True,
|
||||||
)
|
)
|
||||||
dt = -1.0 / self.config.num_steps
|
num_steps = self.config.num_steps
|
||||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
dt = -1.0 / num_steps
|
||||||
|
|
||||||
x_t = noise
|
x_t = noise
|
||||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
for step in range(num_steps):
|
||||||
|
time = 1.0 + step * dt
|
||||||
|
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||||
|
|
||||||
while time >= -dt / 2:
|
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||||
expanded_time = time.expand(bsize)
|
|
||||||
|
|
||||||
# Define a closure function to properly capture expanded_time
|
|
||||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
|
||||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
|
||||||
return self.denoise_step(
|
return self.denoise_step(
|
||||||
x_t=input_x_t,
|
x_t=input_x_t,
|
||||||
prefix_pad_masks=prefix_pad_masks,
|
prefix_pad_masks=prefix_pad_masks,
|
||||||
@@ -818,15 +815,11 @@ class VLAFlowMatching(nn.Module):
|
|||||||
else:
|
else:
|
||||||
v_t = denoise_step_partial_call(x_t)
|
v_t = denoise_step_partial_call(x_t)
|
||||||
|
|
||||||
# Euler step
|
x_t = x_t + dt * v_t
|
||||||
x_t += dt * v_t
|
|
||||||
|
|
||||||
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
|
|
||||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||||
|
|
||||||
time += dt
|
|
||||||
|
|
||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
def denoise_step(
|
def denoise_step(
|
||||||
|
|||||||
Reference in New Issue
Block a user