This commit is contained in:
Dhruv Nair
2026-05-01 06:14:24 +02:00
parent a5bc04696b
commit 9e689bdc51
4 changed files with 12 additions and 4 deletions

View File

@@ -611,7 +611,9 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
image_latents = self._patchify_latents(image_latents)
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
image_latents.device, image_latents.dtype
)
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
return image_latents

View File

@@ -467,7 +467,9 @@ class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
image_latents = self._patchify_latents(image_latents)
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
image_latents.device, image_latents.dtype
)
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
return image_latents

View File

@@ -547,7 +547,9 @@ class Flux2KleinInpaintPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
image_latents = self._patchify_latents(image_latents)
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
image_latents.device, image_latents.dtype
)
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
return image_latents

View File

@@ -477,7 +477,9 @@ class Flux2KleinKVPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
image_latents = self._patchify_latents(image_latents)
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
image_latents.device, image_latents.dtype
)
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
return image_latents