Fix UniPC scheduler device mismatch when using offloading (#13489)

When model/CPU offloading is enabled, self.sigmas may reside on CPU
while the sample tensor is on GPU. The multistep_uni_p_bh_update and
multistep_uni_c_bh_update methods index self.sigmas without moving
the result to the sample device, causing torch.stack(rks) to fail
with "Expected all tensors to be on the same device".

Move sigma values to the sample device immediately after indexing,
ensuring all derived tensors (lambda, h, rk) stay on the correct
device throughout the computation.

Fixes #13488

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
Param
2026-04-30 10:16:20 -05:00
committed by GitHub
parent a5bc04696b
commit 716f246031

View File

@@ -882,7 +882,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
x_t = self.solver_p.step(model_output, s0, x).prev_sample
return x_t
-sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+device = sample.device
+sigma_t, sigma_s0 = self.sigmas[self.step_index + 1].to(device), self.sigmas[self.step_index].to(device)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
@@ -890,14 +891,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
h = lambda_t - lambda_s0
-device = sample.device
rks = []
D1s = []
for i in range(1, order):
si = self.step_index - i
mi = model_output_list[-(i + 1)]
-alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
+alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si].to(device))
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
@@ -1017,7 +1017,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
x_t = this_sample
model_t = this_model_output
-sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
+device = this_sample.device
+sigma_t, sigma_s0 = self.sigmas[self.step_index].to(device), self.sigmas[self.step_index - 1].to(device)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
@@ -1025,14 +1026,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
h = lambda_t - lambda_s0
-device = this_sample.device
rks = []
D1s = []
for i in range(1, order):
si = self.step_index - (i + 1)
mi = model_output_list[-(i + 1)]
-alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
+alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si].to(device))
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
rk = (lambda_si - lambda_s0) / h
rks.append(rk)