Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
| from unik3d.utils.constants import VERBOSE | |
| from unik3d.utils.misc import profile_method | |
| class ResidualConvUnit(nn.Module): | |
| def __init__( | |
| self, | |
| dim, | |
| kernel_size: int = 3, | |
| padding_mode: str = "zeros", | |
| dilation: int = 1, | |
| layer_scale: float = 1.0, | |
| use_norm: bool = False, | |
| ): | |
| super().__init__() | |
| self.conv1 = nn.Conv2d( | |
| dim, | |
| dim, | |
| kernel_size=kernel_size, | |
| padding=dilation * (kernel_size - 1) // 2, | |
| dilation=dilation, | |
| padding_mode=padding_mode, | |
| ) | |
| self.conv2 = nn.Conv2d( | |
| dim, | |
| dim, | |
| kernel_size=kernel_size, | |
| padding=dilation * (kernel_size - 1) // 2, | |
| dilation=dilation, | |
| padding_mode=padding_mode, | |
| ) | |
| self.activation = nn.LeakyReLU() | |
| self.skip_add = nn.quantized.FloatFunctional() | |
| self.gamma = ( | |
| nn.Parameter(layer_scale * torch.ones(1, dim, 1, 1)) | |
| if layer_scale > 0.0 | |
| else 1.0 | |
| ) | |
| self.norm1 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity() | |
| self.norm2 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity() | |
| def forward(self, x): | |
| out = self.activation(x) | |
| out = self.conv1(out) | |
| out = self.norm1(out) | |
| out = self.activation(out) | |
| out = self.conv2(out) | |
| out = self.norm2(out) | |
| return self.skip_add.add(self.gamma * out, x) | |
| class ResUpsampleBil(nn.Module): | |
| def __init__( | |
| self, | |
| hidden_dim, | |
| output_dim: int = None, | |
| num_layers: int = 2, | |
| kernel_size: int = 3, | |
| layer_scale: float = 1.0, | |
| padding_mode: str = "zeros", | |
| use_norm: bool = False, | |
| **kwargs, | |
| ): | |
| super().__init__() | |
| output_dim = output_dim if output_dim is not None else hidden_dim // 2 | |
| self.convs = nn.ModuleList([]) | |
| for _ in range(num_layers): | |
| self.convs.append( | |
| ResidualConvUnit( | |
| hidden_dim, | |
| kernel_size=kernel_size, | |
| layer_scale=layer_scale, | |
| padding_mode=padding_mode, | |
| use_norm=use_norm, | |
| ) | |
| ) | |
| self.up = nn.Sequential( | |
| nn.Conv2d( | |
| hidden_dim, | |
| output_dim, | |
| kernel_size=1, | |
| padding=0, | |
| padding_mode=padding_mode, | |
| ), | |
| nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), | |
| ) | |
| def forward(self, x: torch.Tensor): | |
| for conv in self.convs: | |
| x = conv(x) | |
| x = self.up(x) | |
| return x | |
| class ResUpsample(nn.Module): | |
| def __init__( | |
| self, | |
| hidden_dim, | |
| num_layers: int = 2, | |
| kernel_size: int = 3, | |
| layer_scale: float = 1.0, | |
| padding_mode: str = "zeros", | |
| **kwargs, | |
| ): | |
| super().__init__() | |
| self.convs = nn.ModuleList([]) | |
| for _ in range(num_layers): | |
| self.convs.append( | |
| ResidualConvUnit( | |
| hidden_dim, | |
| kernel_size=kernel_size, | |
| layer_scale=layer_scale, | |
| padding_mode=padding_mode, | |
| ) | |
| ) | |
| self.up = nn.ConvTranspose2d( | |
| hidden_dim, hidden_dim // 2, kernel_size=2, stride=2, padding=0 | |
| ) | |
| def forward(self, x: torch.Tensor): | |
| for conv in self.convs: | |
| x = conv(x) | |
| x = self.up(x) | |
| return x | |
| class ResUpsampleSH(nn.Module): | |
| def __init__( | |
| self, | |
| hidden_dim, | |
| num_layers: int = 2, | |
| kernel_size: int = 3, | |
| layer_scale: float = 1.0, | |
| padding_mode: str = "zeros", | |
| **kwargs, | |
| ): | |
| super().__init__() | |
| self.convs = nn.ModuleList([]) | |
| for _ in range(num_layers): | |
| self.convs.append( | |
| ResidualConvUnit( | |
| hidden_dim, | |
| kernel_size=kernel_size, | |
| layer_scale=layer_scale, | |
| padding_mode=padding_mode, | |
| ) | |
| ) | |
| self.up = nn.Sequential( | |
| nn.PixelShuffle(2), | |
| nn.Conv2d( | |
| hidden_dim // 4, | |
| hidden_dim // 2, | |
| kernel_size=3, | |
| padding=1, | |
| padding_mode=padding_mode, | |
| ), | |
| ) | |
| def forward(self, x: torch.Tensor): | |
| for conv in self.convs: | |
| x = conv(x) | |
| x = self.up(x) | |
| return x | |