diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index aaca3d347b..ba6df51070 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -7,9 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode __version__ = "0.0.4" from .modeling_utils import ModelMixin -from .models.unet import UNetModel -from .models.unet_ldm import UNetLDMModel -from .models.unet_rl import TemporalUNet +from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel from .pipeline_utils import DiffusionPipeline from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 3f0c78b3c6..71e321e111 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -21,3 +21,4 @@ from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, Glide from .unet_grad_tts import UNetGradTTSModel from .unet_ldm import UNetLDMModel from .unet_rl import TemporalUNet +from .unet_sde_score_estimation import NCSNpp diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index 4fdffd33a0..28fea5753c 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -5,6 +5,7 @@ import math import torch import torch.nn as nn + try: import einops from einops.layers.torch import Rearrange @@ -104,14 +105,14 @@ class ResidualTemporalBlock(nn.Module): class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): def __init__( - self, - training_horizon, - transition_dim, - cond_dim, - predict_epsilon=False, - clip_denoised=True, - dim=32, - dim_mults=(1, 2, 4, 8), + self, + training_horizon, + transition_dim, + cond_dim, + predict_epsilon=False, + clip_denoised=True, + dim=32, + dim_mults=(1, 2, 4, 8), ): super().__init__() @@ -211,14 +212,14 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): class TemporalValue(nn.Module): def __init__( - self, - horizon, - transition_dim, - cond_dim, - dim=32, - time_dim=None, - out_dim=1, - dim_mults=(1, 2, 4, 8), + self, + horizon, + transition_dim, + cond_dim, + dim=32, + time_dim=None, + out_dim=1, + dim_mults=(1, 2, 4, 8), ): super().__init__() diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py new file mode 100644 index 0000000000..26b4419ea2 --- /dev/null +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -0,0 +1,1051 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. + +# helpers functions + + +import functools +import math +import string + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native( + input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] + ) + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) + + +# Function ported from StyleGAN2 +def get_weight(module, shape, weight_var="weight", kernel_init=None): + """Get/create weight tensor for a convolution or fully-connected layer.""" + + return module.param(weight_var, kernel_init, shape) + + +class Conv2d(nn.Module): + """Conv2d layer with optimal upsampling and downsampling (StyleGAN2).""" + + def __init__( + self, + in_ch, + out_ch, + kernel, + up=False, + down=False, + resample_kernel=(1, 3, 3, 1), + use_bias=True, + kernel_init=None, + ): + super().__init__() + assert not (up and down) + assert kernel >= 1 and kernel % 2 == 1 + self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel)) + if kernel_init is not None: + self.weight.data = kernel_init(self.weight.data.shape) + if use_bias: + self.bias = nn.Parameter(torch.zeros(out_ch)) + + self.up = up + self.down = down + self.resample_kernel = resample_kernel + self.kernel = kernel + self.use_bias = use_bias + + def forward(self, x): + if self.up: + x = upsample_conv_2d(x, self.weight, k=self.resample_kernel) + elif self.down: + x = conv_downsample_2d(x, self.weight, k=self.resample_kernel) + else: + x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2) + + if self.use_bias: + x = x + self.bias.reshape(1, -1, 1, 1) + + return x + + +def naive_upsample_2d(x, factor=2): + _N, C, H, W = x.shape + x = torch.reshape(x, (-1, C, H, 1, W, 1)) + x = x.repeat(1, 1, 1, factor, 1, factor) + return torch.reshape(x, (-1, C, H * factor, W * factor)) + + +def naive_downsample_2d(x, factor=2): + _N, C, H, W = x.shape + x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor)) + return torch.mean(x, dim=(3, 5)) + + +def upsample_conv_2d(x, w, k=None, factor=2, gain=1): + """Fused `upsample_2d()` followed by `tf.nn.conv2d()`. + + Padding is performed only once at the beginning, not between the + operations. + The fused op is considerably more efficient than performing the same + calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + w: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = + x.shape[0] // numGroups`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` or + `[N, H * factor, W * factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + + # Check weight shape. + assert len(w.shape) == 4 + convH = w.shape[2] + convW = w.shape[3] + inC = w.shape[1] + + assert convW == convH + + # Setup filter kernel. + if k is None: + k = [1] * factor + k = _setup_kernel(k) * (gain * (factor**2)) + p = (k.shape[0] - factor) - (convW - 1) + + stride = (factor, factor) + + # Determine data dimensions. + stride = [1, 1, factor, factor] + output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW) + output_padding = ( + output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH, + output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW, + ) + assert output_padding[0] >= 0 and output_padding[1] >= 0 + num_groups = _shape(x, 1) // inC + + # Transpose weights. + w = torch.reshape(w, (num_groups, -1, inC, convH, convW)) + w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) + w = torch.reshape(w, (num_groups * inC, -1, convH, convW)) + + x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0) + # Original TF code. + # x = tf.nn.conv2d_transpose( + # x, + # w, + # output_shape=output_shape, + # strides=stride, + # padding='VALID', + # data_format=data_format) + # JAX equivalent + + return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) + + +def conv_downsample_2d(x, w, k=None, factor=2, gain=1): + """Fused `tf.nn.conv2d()` followed by `downsample_2d()`. + + Padding is performed only once at the beginning, not between the operations. + The fused op is considerably more efficient than performing the same + calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + w: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = + x.shape[0] // numGroups`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` or + `[N, H // factor, W // factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + _outC, _inC, convH, convW = w.shape + assert convW == convH + if k is None: + k = [1] * factor + k = _setup_kernel(k) * gain + p = (k.shape[0] - factor) + (convW - 1) + s = [factor, factor] + x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2)) + return F.conv2d(x, w, stride=s, padding=0) + + +def _setup_kernel(k): + k = np.asarray(k, dtype=np.float32) + if k.ndim == 1: + k = np.outer(k, k) + k /= np.sum(k) + assert k.ndim == 2 + assert k.shape[0] == k.shape[1] + return k + + +def _shape(x, dim): + return x.shape[dim] + + +def upsample_2d(x, k=None, factor=2, gain=1): + r"""Upsample a batch of 2D images with the given filter. + + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and upsamples each image with the given filter. The filter is normalized so + that + if the input pixels are constant, they will be scaled by the specified + `gain`. + Pixels outside the image are assumed to be zero, and the filter is padded + with + zeros so that its shape is a multiple of the upsampling factor. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if k is None: + k = [1] * factor + k = _setup_kernel(k) * (gain * (factor**2)) + p = k.shape[0] - factor + return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) + + +def downsample_2d(x, k=None, factor=2, gain=1): + r"""Downsample a batch of 2D images with the given filter. + + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and downsamples each image with the given filter. The filter is normalized + so that + if the input pixels are constant, they will be scaled by the specified + `gain`. + Pixels outside the image are assumed to be zero, and the filter is padded + with + zeros so that its shape is a multiple of the downsampling factor. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if k is None: + k = [1] * factor + k = _setup_kernel(k) * gain + p = k.shape[0] - factor + return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) + + +def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0): + """1x1 convolution with DDPM initialization.""" + conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias) + conv.weight.data = default_init(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + + +def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1): + """3x3 convolution with DDPM initialization.""" + conv = nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias + ) + conv.weight.data = default_init(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + + +conv1x1 = ddpm_conv1x1 +conv3x3 = ddpm_conv3x3 + + +def _einsum(a, b, c, x, y): + einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c)) + return torch.einsum(einsum_str, x, y) + + +def contract_inner(x, y): + """tensordot(x, y, 1).""" + x_chars = list(string.ascii_lowercase[: len(x.shape)]) + y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)]) + y_chars[0] = x_chars[-1] # first axis of y and last of x get summed + out_chars = x_chars[:-1] + y_chars[1:] + return _einsum(x_chars, y_chars, out_chars, x, y) + + +class NIN(nn.Module): + def __init__(self, in_dim, num_units, init_scale=0.1): + super().__init__() + self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True) + self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + y = contract_inner(x, self.W) + self.b + return y.permute(0, 3, 1, 2) + + +def get_act(config): + """Get activation functions from the config file.""" + + if config.model.nonlinearity.lower() == "elu": + return nn.ELU() + elif config.model.nonlinearity.lower() == "relu": + return nn.ReLU() + elif config.model.nonlinearity.lower() == "lrelu": + return nn.LeakyReLU(negative_slope=0.2) + elif config.model.nonlinearity.lower() == "swish": + return nn.SiLU() + else: + raise NotImplementedError("activation function does not exist!") + + +def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): + assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 + half_dim = embedding_dim // 2 + # magic number 10000 is from transformers + emb = math.log(max_positions) / (half_dim - 1) + # emb = math.log(2.) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) + # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] + # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = F.pad(emb, (0, 1), mode="constant") + assert emb.shape == (timesteps.shape[0], embedding_dim) + return emb + + +def default_init(scale=1.0): + """The same initialization used in DDPM.""" + scale = 1e-10 if scale == 0 else scale + return variance_scaling(scale, "fan_avg", "uniform") + + +def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"): + """Ported from JAX.""" + + def _compute_fans(shape, in_axis=1, out_axis=0): + receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] + fan_in = shape[in_axis] * receptive_field_size + fan_out = shape[out_axis] * receptive_field_size + return fan_in, fan_out + + def init(shape, dtype=dtype, device=device): + fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) + if mode == "fan_in": + denominator = fan_in + elif mode == "fan_out": + denominator = fan_out + elif mode == "fan_avg": + denominator = (fan_in + fan_out) / 2 + else: + raise ValueError("invalid mode for variance scaling initializer: {}".format(mode)) + variance = scale / denominator + if distribution == "normal": + return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance) + elif distribution == "uniform": + return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance) + else: + raise ValueError("invalid distribution for variance scaling initializer") + + return init + + +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__(self, embedding_size=256, scale=1.0): + super().__init__() + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + def forward(self, x): + x_proj = x[:, None] * self.W[None, :] * 2 * np.pi + return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + + +class Combine(nn.Module): + """Combine information from skip connections.""" + + def __init__(self, dim1, dim2, method="cat"): + super().__init__() + self.Conv_0 = conv1x1(dim1, dim2) + self.method = method + + def forward(self, x, y): + h = self.Conv_0(x) + if self.method == "cat": + return torch.cat([h, y], dim=1) + elif self.method == "sum": + return h + y + else: + raise ValueError(f"Method {self.method} not recognized.") + + +class AttnBlockpp(nn.Module): + """Channel-wise self-attention block. Modified from DDPM.""" + + def __init__(self, channels, skip_rescale=False, init_scale=0.0): + super().__init__() + self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6) + self.NIN_0 = NIN(channels, channels) + self.NIN_1 = NIN(channels, channels) + self.NIN_2 = NIN(channels, channels) + self.NIN_3 = NIN(channels, channels, init_scale=init_scale) + self.skip_rescale = skip_rescale + + def forward(self, x): + B, C, H, W = x.shape + h = self.GroupNorm_0(x) + q = self.NIN_0(h) + k = self.NIN_1(h) + v = self.NIN_2(h) + + w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5)) + w = torch.reshape(w, (B, H, W, H * W)) + w = F.softmax(w, dim=-1) + w = torch.reshape(w, (B, H, W, H, W)) + h = torch.einsum("bhwij,bcij->bchw", w, v) + h = self.NIN_3(h) + if not self.skip_rescale: + return x + h + else: + return (x + h) / np.sqrt(2.0) + + +class Upsample(nn.Module): + def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_ch = out_ch if out_ch else in_ch + if not fir: + if with_conv: + self.Conv_0 = conv3x3(in_ch, out_ch) + else: + if with_conv: + self.Conv2d_0 = Conv2d( + in_ch, + out_ch, + kernel=3, + up=True, + resample_kernel=fir_kernel, + use_bias=True, + kernel_init=default_init(), + ) + self.fir = fir + self.with_conv = with_conv + self.fir_kernel = fir_kernel + self.out_ch = out_ch + + def forward(self, x): + B, C, H, W = x.shape + if not self.fir: + h = F.interpolate(x, (H * 2, W * 2), "nearest") + if self.with_conv: + h = self.Conv_0(h) + else: + if not self.with_conv: + h = upsample_2d(x, self.fir_kernel, factor=2) + else: + h = self.Conv2d_0(x) + + return h + + +class Downsample(nn.Module): + def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_ch = out_ch if out_ch else in_ch + if not fir: + if with_conv: + self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0) + else: + if with_conv: + self.Conv2d_0 = Conv2d( + in_ch, + out_ch, + kernel=3, + down=True, + resample_kernel=fir_kernel, + use_bias=True, + kernel_init=default_init(), + ) + self.fir = fir + self.fir_kernel = fir_kernel + self.with_conv = with_conv + self.out_ch = out_ch + + def forward(self, x): + B, C, H, W = x.shape + if not self.fir: + if self.with_conv: + x = F.pad(x, (0, 1, 0, 1)) + x = self.Conv_0(x) + else: + x = F.avg_pool2d(x, 2, stride=2) + else: + if not self.with_conv: + x = downsample_2d(x, self.fir_kernel, factor=2) + else: + x = self.Conv2d_0(x) + + return x + + +class ResnetBlockDDPMpp(nn.Module): + """ResBlock adapted from DDPM.""" + + def __init__( + self, + act, + in_ch, + out_ch=None, + temb_dim=None, + conv_shortcut=False, + dropout=0.1, + skip_rescale=False, + init_scale=0.0, + ): + super().__init__() + out_ch = out_ch if out_ch else in_ch + self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) + self.Conv_0 = conv3x3(in_ch, out_ch) + if temb_dim is not None: + self.Dense_0 = nn.Linear(temb_dim, out_ch) + self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) + nn.init.zeros_(self.Dense_0.bias) + self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) + self.Dropout_0 = nn.Dropout(dropout) + self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) + if in_ch != out_ch: + if conv_shortcut: + self.Conv_2 = conv3x3(in_ch, out_ch) + else: + self.NIN_0 = NIN(in_ch, out_ch) + + self.skip_rescale = skip_rescale + self.act = act + self.out_ch = out_ch + self.conv_shortcut = conv_shortcut + + def forward(self, x, temb=None): + h = self.act(self.GroupNorm_0(x)) + h = self.Conv_0(h) + if temb is not None: + h += self.Dense_0(self.act(temb))[:, :, None, None] + h = self.act(self.GroupNorm_1(h)) + h = self.Dropout_0(h) + h = self.Conv_1(h) + if x.shape[1] != self.out_ch: + if self.conv_shortcut: + x = self.Conv_2(x) + else: + x = self.NIN_0(x) + if not self.skip_rescale: + return x + h + else: + return (x + h) / np.sqrt(2.0) + + +class ResnetBlockBigGANpp(nn.Module): + def __init__( + self, + act, + in_ch, + out_ch=None, + temb_dim=None, + up=False, + down=False, + dropout=0.1, + fir=False, + fir_kernel=(1, 3, 3, 1), + skip_rescale=True, + init_scale=0.0, + ): + super().__init__() + + out_ch = out_ch if out_ch else in_ch + self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) + self.up = up + self.down = down + self.fir = fir + self.fir_kernel = fir_kernel + + self.Conv_0 = conv3x3(in_ch, out_ch) + if temb_dim is not None: + self.Dense_0 = nn.Linear(temb_dim, out_ch) + self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) + nn.init.zeros_(self.Dense_0.bias) + + self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) + self.Dropout_0 = nn.Dropout(dropout) + self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) + if in_ch != out_ch or up or down: + self.Conv_2 = conv1x1(in_ch, out_ch) + + self.skip_rescale = skip_rescale + self.act = act + self.in_ch = in_ch + self.out_ch = out_ch + + def forward(self, x, temb=None): + h = self.act(self.GroupNorm_0(x)) + + if self.up: + if self.fir: + h = upsample_2d(h, self.fir_kernel, factor=2) + x = upsample_2d(x, self.fir_kernel, factor=2) + else: + h = naive_upsample_2d(h, factor=2) + x = naive_upsample_2d(x, factor=2) + elif self.down: + if self.fir: + h = downsample_2d(h, self.fir_kernel, factor=2) + x = downsample_2d(x, self.fir_kernel, factor=2) + else: + h = naive_downsample_2d(h, factor=2) + x = naive_downsample_2d(x, factor=2) + + h = self.Conv_0(h) + # Add bias to each feature map conditioned on the time embedding + if temb is not None: + h += self.Dense_0(self.act(temb))[:, :, None, None] + h = self.act(self.GroupNorm_1(h)) + h = self.Dropout_0(h) + h = self.Conv_1(h) + + if self.in_ch != self.out_ch or self.up or self.down: + x = self.Conv_2(x) + + if not self.skip_rescale: + return x + h + else: + return (x + h) / np.sqrt(2.0) + + +class NCSNpp(nn.Module): + """NCSN++ model""" + + def __init__(self, config): + super().__init__() + self.config = config + self.act = act = get_act(config) + # self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) + + self.nf = nf = config.model.nf + ch_mult = config.model.ch_mult + self.num_res_blocks = num_res_blocks = config.model.num_res_blocks + self.attn_resolutions = attn_resolutions = config.model.attn_resolutions + dropout = config.model.dropout + resamp_with_conv = config.model.resamp_with_conv + self.num_resolutions = num_resolutions = len(ch_mult) + self.all_resolutions = all_resolutions = [config.data.image_size // (2**i) for i in range(num_resolutions)] + + self.conditional = conditional = config.model.conditional # noise-conditional + fir = config.model.fir + fir_kernel = config.model.fir_kernel + self.skip_rescale = skip_rescale = config.model.skip_rescale + self.resblock_type = resblock_type = config.model.resblock_type.lower() + self.progressive = progressive = config.model.progressive.lower() + self.progressive_input = progressive_input = config.model.progressive_input.lower() + self.embedding_type = embedding_type = config.model.embedding_type.lower() + init_scale = config.model.init_scale + assert progressive in ["none", "output_skip", "residual"] + assert progressive_input in ["none", "input_skip", "residual"] + assert embedding_type in ["fourier", "positional"] + combine_method = config.model.progressive_combine.lower() + combiner = functools.partial(Combine, method=combine_method) + + modules = [] + # timestep/noise_level embedding; only for continuous training + if embedding_type == "fourier": + # Gaussian Fourier features embeddings. + assert config.training.continuous, "Fourier features are only used for continuous training." + + modules.append(GaussianFourierProjection(embedding_size=nf, scale=config.model.fourier_scale)) + embed_dim = 2 * nf + + elif embedding_type == "positional": + embed_dim = nf + + else: + raise ValueError(f"embedding type {embedding_type} unknown.") + + if conditional: + modules.append(nn.Linear(embed_dim, nf * 4)) + modules[-1].weight.data = default_init()(modules[-1].weight.shape) + nn.init.zeros_(modules[-1].bias) + modules.append(nn.Linear(nf * 4, nf * 4)) + modules[-1].weight.data = default_init()(modules[-1].weight.shape) + nn.init.zeros_(modules[-1].bias) + + AttnBlock = functools.partial(AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale) + + Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) + + if progressive == "output_skip": + self.pyramid_upsample = Up_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False) + elif progressive == "residual": + pyramid_upsample = functools.partial(Up_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True) + + Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) + + if progressive_input == "input_skip": + self.pyramid_downsample = Down_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False) + elif progressive_input == "residual": + pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True) + + if resblock_type == "ddpm": + ResnetBlock = functools.partial( + ResnetBlockDDPMpp, + act=act, + dropout=dropout, + init_scale=init_scale, + skip_rescale=skip_rescale, + temb_dim=nf * 4, + ) + + elif resblock_type == "biggan": + ResnetBlock = functools.partial( + ResnetBlockBigGANpp, + act=act, + dropout=dropout, + fir=fir, + fir_kernel=fir_kernel, + init_scale=init_scale, + skip_rescale=skip_rescale, + temb_dim=nf * 4, + ) + + else: + raise ValueError(f"resblock type {resblock_type} unrecognized.") + + # Downsampling block + + channels = config.data.num_channels + if progressive_input != "none": + input_pyramid_ch = channels + + modules.append(conv3x3(channels, nf)) + hs_c = [nf] + + in_ch = nf + for i_level in range(num_resolutions): + # Residual blocks for this resolution + for i_block in range(num_res_blocks): + out_ch = nf * ch_mult[i_level] + modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) + in_ch = out_ch + + if all_resolutions[i_level] in attn_resolutions: + modules.append(AttnBlock(channels=in_ch)) + hs_c.append(in_ch) + + if i_level != num_resolutions - 1: + if resblock_type == "ddpm": + modules.append(Downsample(in_ch=in_ch)) + else: + modules.append(ResnetBlock(down=True, in_ch=in_ch)) + + if progressive_input == "input_skip": + modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch)) + if combine_method == "cat": + in_ch *= 2 + + elif progressive_input == "residual": + modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch)) + input_pyramid_ch = in_ch + + hs_c.append(in_ch) + + in_ch = hs_c[-1] + modules.append(ResnetBlock(in_ch=in_ch)) + modules.append(AttnBlock(channels=in_ch)) + modules.append(ResnetBlock(in_ch=in_ch)) + + pyramid_ch = 0 + # Upsampling block + for i_level in reversed(range(num_resolutions)): + for i_block in range(num_res_blocks + 1): + out_ch = nf * ch_mult[i_level] + modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) + in_ch = out_ch + + if all_resolutions[i_level] in attn_resolutions: + modules.append(AttnBlock(channels=in_ch)) + + if progressive != "none": + if i_level == num_resolutions - 1: + if progressive == "output_skip": + modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) + modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) + pyramid_ch = channels + elif progressive == "residual": + modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) + modules.append(conv3x3(in_ch, in_ch, bias=True)) + pyramid_ch = in_ch + else: + raise ValueError(f"{progressive} is not a valid name.") + else: + if progressive == "output_skip": + modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) + modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale)) + pyramid_ch = channels + elif progressive == "residual": + modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch)) + pyramid_ch = in_ch + else: + raise ValueError(f"{progressive} is not a valid name") + + if i_level != 0: + if resblock_type == "ddpm": + modules.append(Upsample(in_ch=in_ch)) + else: + modules.append(ResnetBlock(in_ch=in_ch, up=True)) + + assert not hs_c + + if progressive != "output_skip": + modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) + modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) + + self.all_modules = nn.ModuleList(modules) + + def forward(self, x, time_cond): + # import ipdb; ipdb.set_trace() + # timestep/noise_level embedding; only for continuous training + modules = self.all_modules + m_idx = 0 + if self.embedding_type == "fourier": + # Gaussian Fourier features embeddings. + used_sigmas = time_cond + temb = modules[m_idx](torch.log(used_sigmas)) + m_idx += 1 + + elif self.embedding_type == "positional": + # Sinusoidal positional embeddings. + timesteps = time_cond + used_sigmas = self.sigmas[time_cond.long()] + temb = get_timestep_embedding(timesteps, self.nf) + + else: + raise ValueError(f"embedding type {self.embedding_type} unknown.") + + if self.conditional: + temb = modules[m_idx](temb) + m_idx += 1 + temb = modules[m_idx](self.act(temb)) + m_idx += 1 + else: + temb = None + + if not self.config.data.centered: + # If input data is in [0, 1] + x = 2 * x - 1.0 + + # Downsampling block + input_pyramid = None + if self.progressive_input != "none": + input_pyramid = x + + hs = [modules[m_idx](x)] + m_idx += 1 + for i_level in range(self.num_resolutions): + # Residual blocks for this resolution + for i_block in range(self.num_res_blocks): + h = modules[m_idx](hs[-1], temb) + m_idx += 1 + if h.shape[-1] in self.attn_resolutions: + h = modules[m_idx](h) + m_idx += 1 + + hs.append(h) + + if i_level != self.num_resolutions - 1: + if self.resblock_type == "ddpm": + h = modules[m_idx](hs[-1]) + m_idx += 1 + else: + h = modules[m_idx](hs[-1], temb) + m_idx += 1 + + if self.progressive_input == "input_skip": + input_pyramid = self.pyramid_downsample(input_pyramid) + h = modules[m_idx](input_pyramid, h) + m_idx += 1 + + elif self.progressive_input == "residual": + input_pyramid = modules[m_idx](input_pyramid) + m_idx += 1 + if self.skip_rescale: + input_pyramid = (input_pyramid + h) / np.sqrt(2.0) + else: + input_pyramid = input_pyramid + h + h = input_pyramid + + hs.append(h) + + h = hs[-1] + h = modules[m_idx](h, temb) + m_idx += 1 + h = modules[m_idx](h) + m_idx += 1 + h = modules[m_idx](h, temb) + m_idx += 1 + + pyramid = None + + # Upsampling block + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb) + m_idx += 1 + + if h.shape[-1] in self.attn_resolutions: + h = modules[m_idx](h) + m_idx += 1 + + if self.progressive != "none": + if i_level == self.num_resolutions - 1: + if self.progressive == "output_skip": + pyramid = self.act(modules[m_idx](h)) + m_idx += 1 + pyramid = modules[m_idx](pyramid) + m_idx += 1 + elif self.progressive == "residual": + pyramid = self.act(modules[m_idx](h)) + m_idx += 1 + pyramid = modules[m_idx](pyramid) + m_idx += 1 + else: + raise ValueError(f"{self.progressive} is not a valid name.") + else: + if self.progressive == "output_skip": + pyramid = self.pyramid_upsample(pyramid) + pyramid_h = self.act(modules[m_idx](h)) + m_idx += 1 + pyramid_h = modules[m_idx](pyramid_h) + m_idx += 1 + pyramid = pyramid + pyramid_h + elif self.progressive == "residual": + pyramid = modules[m_idx](pyramid) + m_idx += 1 + if self.skip_rescale: + pyramid = (pyramid + h) / np.sqrt(2.0) + else: + pyramid = pyramid + h + h = pyramid + else: + raise ValueError(f"{self.progressive} is not a valid name") + + if i_level != 0: + if self.resblock_type == "ddpm": + h = modules[m_idx](h) + m_idx += 1 + else: + h = modules[m_idx](h, temb) + m_idx += 1 + + assert not hs + + if self.progressive == "output_skip": + h = pyramid + else: + h = self.act(modules[m_idx](h)) + m_idx += 1 + h = modules[m_idx](h) + m_idx += 1 + + assert m_idx == len(modules) + if self.config.model.scale_by_sigma: + used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:])))) + h = h / used_sigmas + + return h