diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 9b6f38ad4..b7ddb4e99 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -812,16 +812,13 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` ) dt = -1.0 / num_steps - dt = torch.tensor(dt, dtype=torch.float32, device=device) x_t = noise - time = torch.tensor(1.0, dtype=torch.float32, device=device) - while time >= -dt / 2: - expanded_time = time.expand(bsize) + for step in range(num_steps): + time = 1.0 + step * dt + time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize) - # Define a closure function to properly capture expanded_time - # This avoids the lambda expression (E731) and loop variable binding (B023) issues - def denoise_step_partial_call(input_x_t, current_timestep=expanded_time): + def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): return self.denoise_step( state=state, prefix_pad_masks=prefix_pad_masks, @@ -846,15 +843,11 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` else: v_t = denoise_step_partial_call(x_t) - # Euler step - x_t += dt * v_t + x_t = x_t + dt * v_t - # Record x_t and v_t after Euler step if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled(): self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t) - time += dt - return x_t def denoise_step( diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 6500ada20..fde760168 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -787,16 +787,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` ) dt = -1.0 / num_steps - dt = torch.tensor(dt, dtype=torch.float32, device=device) x_t = noise - time = torch.tensor(1.0, dtype=torch.float32, device=device) - while time >= -dt / 2: - expanded_time = time.expand(bsize) + for step in range(num_steps): + time = 1.0 + step * dt + time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize) - # Define a closure function to properly capture expanded_time - # This avoids the lambda expression (E731) and loop variable binding (B023) issues - def denoise_step_partial_call(input_x_t, current_timestep=expanded_time): + def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): return self.denoise_step( prefix_pad_masks=prefix_pad_masks, past_key_values=past_key_values, @@ -820,15 +817,11 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` else: v_t = denoise_step_partial_call(x_t) - # Euler step - x_t += dt * v_t + x_t = x_t + dt * v_t - # Record x_t and v_t after Euler step if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled(): self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t) - time += dt - return x_t def denoise_step( diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index e442b14d5..cce41def8 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -783,18 +783,15 @@ class VLAFlowMatching(nn.Module): use_cache=self.config.use_cache, fill_kv_cache=True, ) - dt = -1.0 / self.config.num_steps - dt = torch.tensor(dt, dtype=torch.float32, device=device) + num_steps = self.config.num_steps + dt = -1.0 / num_steps x_t = noise - time = torch.tensor(1.0, dtype=torch.float32, device=device) + for step in range(num_steps): + time = 1.0 + step * dt + time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize) - while time >= -dt / 2: - expanded_time = time.expand(bsize) - - # Define a closure function to properly capture expanded_time - # This avoids the lambda expression (E731) and loop variable binding (B023) issues - def denoise_step_partial_call(input_x_t, current_timestep=expanded_time): + def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): return self.denoise_step( x_t=input_x_t, prefix_pad_masks=prefix_pad_masks, @@ -818,15 +815,11 @@ class VLAFlowMatching(nn.Module): else: v_t = denoise_step_partial_call(x_t) - # Euler step - x_t += dt * v_t + x_t = x_t + dt * v_t - # Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step) if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled(): self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t) - time += dt - return x_t def denoise_step(