| """ | |
| Author: Luigi Piccinelli | |
| Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) | |
| """ | |
| from math import tanh | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| from timm.models.layers import trunc_normal_ | |
| from unik3d.layers import (MLP, AttentionBlock, AttentionLayer, GradChoker, | |
| PositionEmbeddingSine, ResUpsampleBil) | |
| from unik3d.utils.coordinate import coords_grid | |
| from unik3d.utils.geometric import flat_interpolate | |
| from unik3d.utils.misc import get_params | |
| from unik3d.utils.positional_embedding import generate_fourier_features | |
| from unik3d.utils.sht import rsh_cart_3 | |
| from unik3d.utils.misc import profile_method | |


def orthonormal_init(num_tokens, dims):
    pe = torch.randn(num_tokens, dims)
    # Apply the Gram-Schmidt process to make the matrix orthonormal
    for i in range(num_tokens):
        for j in range(i):
            pe[i] -= torch.dot(pe[i], pe[j]) * pe[j]
        pe[i] = F.normalize(pe[i], p=2, dim=0)
    return pe
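
# Illustrative sketch (not part of the original file): the rows returned by
# orthonormal_init are mutually orthogonal unit vectors, so their Gram matrix
# is approximately the identity, e.g.:
#
#   pe = orthonormal_init(4, 32)
#   assert torch.allclose(pe @ pe.T, torch.eye(4), atol=1e-4)
#
# The Decoder below uses this to build fixed (non-trainable) level embeddings.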


class ListAdapter(nn.Module):
    def __init__(self, input_dims: list[int], hidden_dim: int):
        super().__init__()
        self.input_adapters = nn.ModuleList([])
        self.num_chunks = len(input_dims)
        self.checkpoint = True
        for input_dim in input_dims:
            self.input_adapters.append(nn.Linear(input_dim, hidden_dim))

    def forward(self, xs: list[torch.Tensor]) -> list[torch.Tensor]:
        # Project each input in the list with its own linear adapter.
        outs = [self.input_adapters[i](x) for i, x in enumerate(xs)]
        return outs
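
# Illustrative usage sketch (shapes are assumed, not taken from the original
# code): ListAdapter maps a list of per-level features with heterogeneous
# channel widths onto a common hidden size.
#
#   adapter = ListAdapter(input_dims=[96, 192, 384, 768], hidden_dim=256)
#   feats = [torch.randn(2, 1024, d) for d in (96, 192, 384, 768)]
#   outs = adapter(feats)  # list of 4 tensors, each of shape (2, 1024, 256)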


class AngularModule(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        expansion: int = 4,
        dropout: float = 0.0,
        layer_scale: float = 1.0,
        **kwargs,
    ):
        super().__init__()
        self.pin_params = 3
        self.deg1_params = 3
        self.deg2_params = 5
        self.deg3_params = 7
        self.num_params = (
            self.pin_params + self.deg1_params + self.deg2_params + self.deg3_params
        )
        self.aggregate1 = AttentionBlock(
            hidden_dim,
            num_heads=num_heads,
            expansion=expansion,
            dropout=dropout,
            layer_scale=layer_scale,
        )
        self.aggregate2 = AttentionBlock(
            hidden_dim,
            num_heads=num_heads,
            expansion=expansion,
            dropout=dropout,
            layer_scale=layer_scale,
        )
        self.latents_pos = nn.Parameter(
            torch.randn(1, self.num_params, hidden_dim), requires_grad=True
        )
        self.in_features = nn.Identity()
        self.project_pin = nn.Linear(
            hidden_dim, self.pin_params * hidden_dim, bias=False
        )
        self.project_deg1 = nn.Linear(
            hidden_dim, self.deg1_params * hidden_dim, bias=False
        )
        self.project_deg2 = nn.Linear(
            hidden_dim, self.deg2_params * hidden_dim, bias=False
        )
        self.project_deg3 = nn.Linear(
            hidden_dim, self.deg3_params * hidden_dim, bias=False
        )
        self.out_pinhole = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=1)
        self.out_deg1 = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=3)
        self.out_deg2 = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=3)
        self.out_deg3 = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=3)

    def fill_intrinsics(self, x):
        hfov, cx, cy = x.unbind(dim=-1)
        # The -1.1 shift makes sigmoid(x - 1.1) ~= 0.25 at x = 0, so hfov
        # defaults to ~pi/2 after the 2*pi scaling below.
        hfov = torch.sigmoid(hfov - 1.1)
        ratio = self.shapes[0] / self.shapes[1]
        vfov = hfov * ratio
        cx = torch.sigmoid(cx)
        cy = torch.sigmoid(cy)
        correction_tensor = torch.tensor(
            [2 * torch.pi, 2 * torch.pi, self.shapes[1], self.shapes[0]],
            device=x.device,
            dtype=x.dtype,
        )
        intrinsics = torch.stack([hfov, vfov, cx, cy], dim=1)
        intrinsics = correction_tensor.unsqueeze(0) * intrinsics
        return intrinsics
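
    # Worked example of the mapping above (values are approximate): for
    # x = (0, 0, 0) and shapes (H, W),
    #   hfov ~= 0.25 * 2*pi = pi/2,   vfov ~= hfov * H / W,
    #   cx   ~= 0.5 * W,              cy   ~= 0.5 * H,
    # i.e. a ~90-degree horizontal FoV with a centered principal point.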

    def forward(self, cls_tokens) -> tuple[torch.Tensor, torch.Tensor]:
        latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1)
        pin_tokens, deg1_tokens, deg2_tokens, deg3_tokens = cls_tokens.chunk(4, dim=1)
        pin_tokens = rearrange(
            self.project_pin(pin_tokens), "b n (h c) -> b (n h) c", h=self.pin_params
        )
        deg1_tokens = rearrange(
            self.project_deg1(deg1_tokens),
            "b n (h c) -> b (n h) c",
            h=self.deg1_params,
        )
        deg2_tokens = rearrange(
            self.project_deg2(deg2_tokens),
            "b n (h c) -> b (n h) c",
            h=self.deg2_params,
        )
        deg3_tokens = rearrange(
            self.project_deg3(deg3_tokens),
            "b n (h c) -> b (n h) c",
            h=self.deg3_params,
        )
        tokens = torch.cat([pin_tokens, deg1_tokens, deg2_tokens, deg3_tokens], dim=1)
        tokens = self.aggregate1(tokens, pos_embed=latents_pos)
        tokens = self.aggregate2(tokens, pos_embed=latents_pos)
        tokens_pinhole, tokens_deg1, tokens_deg2, tokens_deg3 = torch.split(
            tokens,
            [self.pin_params, self.deg1_params, self.deg2_params, self.deg3_params],
            dim=1,
        )
        x = self.out_pinhole(tokens_pinhole).squeeze(-1)
        d1 = self.out_deg1(tokens_deg1)
        d2 = self.out_deg2(tokens_deg2)
        d3 = self.out_deg3(tokens_deg3)
        camera_intrinsics = self.fill_intrinsics(x)
        return camera_intrinsics, torch.cat([d1, d2, d3], dim=1)

    def set_shapes(self, shapes: tuple[int, int]):
        self.shapes = shapes
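
# Shape sketch for AngularModule (sizes are assumed, for illustration only):
#
#   module = AngularModule(hidden_dim=256)
#   module.set_shapes((480, 640))          # (H, W) of the input image
#   cls_tokens = torch.randn(2, 4, 256)    # assumed: 4 camera tokens per sample
#   intrinsics, sh_coeffs = module(cls_tokens)
#   # intrinsics: (2, 4)     -> [hfov, vfov, cx, cy]
#   # sh_coeffs:  (2, 15, 3) -> coefficients for degree-1/2/3 spherical harmonics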


class RadialModule(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        expansion: int = 4,
        depths: int | list[int] = 4,
        camera_dim: int = 256,
        dropout: float = 0.0,
        kernel_size: int = 7,
        layer_scale: float = 1.0,
        out_dim: int = 1,
        num_prompt_blocks: int = 1,
        use_norm: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        self.camera_dim = camera_dim
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.ups = nn.ModuleList([])
        self.depth_mlp = nn.ModuleList([])
        self.process_features = nn.ModuleList([])
        self.project_features = nn.ModuleList([])
        self.out = nn.ModuleList([])
        self.prompt_camera = nn.ModuleList([])
        mult = 2
        self.to_latents = nn.Linear(hidden_dim, hidden_dim)
        # Cross-attention layers that condition each feature level on the
        # camera-ray embedding.
        for _ in range(4):
            self.prompt_camera.append(
                AttentionLayer(
                    num_blocks=num_prompt_blocks,
                    dim=hidden_dim,
                    num_heads=num_heads,
                    expansion=expansion,
                    dropout=dropout,
                    layer_scale=-1.0,
                    context_dim=hidden_dim,
                )
            )
        # `depths` is expected to be a list with one entry per upsampling stage.
        for i, depth in enumerate(depths):
            current_dim = min(hidden_dim, mult * hidden_dim // int(2**i))
            next_dim = mult * hidden_dim // int(2 ** (i + 1))
            output_dim = max(next_dim, out_dim)
            self.process_features.append(
                nn.ConvTranspose2d(
                    hidden_dim,
                    current_dim,
                    kernel_size=max(1, 2 * i),
                    stride=max(1, 2 * i),
                    padding=0,
                )
            )
            self.ups.append(
                ResUpsampleBil(
                    current_dim,
                    output_dim=output_dim,
                    expansion=expansion,
                    layer_scale=layer_scale,
                    kernel_size=kernel_size,
                    num_layers=depth,
                    use_norm=use_norm,
                )
            )
            depth_mlp = (
                nn.Sequential(nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim))
                if i == len(depths) - 1
                else nn.Identity()
            )
            self.depth_mlp.append(depth_mlp)
        # The heads below reuse `next_dim`/`output_dim` from the last loop
        # iteration, i.e. the finest decoder stage.
        self.confidence_mlp = nn.Sequential(
            nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim)
        )
        self.to_depth_lr = nn.Conv2d(
            output_dim,
            output_dim // 2,
            kernel_size=3,
            padding=1,
            padding_mode="reflect",
        )
        self.to_confidence_lr = nn.Conv2d(
            output_dim,
            output_dim // 2,
            kernel_size=3,
            padding=1,
            padding_mode="reflect",
        )
        self.to_depth_hr = nn.Sequential(
            nn.Conv2d(
                output_dim // 2, 32, kernel_size=3, padding=1, padding_mode="reflect"
            ),
            nn.LeakyReLU(),
            nn.Conv2d(32, 1, kernel_size=1),
        )
        self.to_confidence_hr = nn.Sequential(
            nn.Conv2d(
                output_dim // 2, 32, kernel_size=3, padding=1, padding_mode="reflect"
            ),
            nn.LeakyReLU(),
            nn.Conv2d(32, 1, kernel_size=1),
        )

    def set_original_shapes(self, shapes: tuple[int, int]):
        self.original_shapes = shapes

    def set_shapes(self, shapes: tuple[int, int]):
        self.shapes = shapes

    def embed_rays(self, rays):
        rays_embedding = flat_interpolate(
            rays, old=self.original_shapes, new=self.shapes, antialias=True
        )
        rays_embedding = rays_embedding / torch.norm(
            rays_embedding, dim=-1, keepdim=True
        ).clip(min=1e-4)
        x, y, z = (
            rays_embedding[..., 0],
            rays_embedding[..., 1],
            rays_embedding[..., 2],
        )
        polar = torch.acos(z)
        x_clipped = x.abs().clip(min=1e-3) * (2 * (x >= 0).int() - 1)
        azimuth = torch.atan2(y, x_clipped)
        rays_embedding = torch.stack([polar, azimuth], dim=-1)
        rays_embedding = generate_fourier_features(
            rays_embedding,
            dim=self.hidden_dim,
            max_freq=max(self.shapes) // 2,
            use_log=True,
            cat_orig=False,
        )
        return rays_embedding
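
    # The block above parameterizes each unit ray by its spherical angles,
    # polar = acos(z) and azimuth = atan2(y, x), and lifts the two angles to
    # `hidden_dim` channels with log-spaced Fourier features. Rough shape
    # sketch (sizes assumed): rays (B, H*W, 3) -> angles (B, h*w, 2)
    # -> embedding (B, h*w, hidden_dim), with (h, w) = self.shapes.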

    def condition(self, feat, rays_embeddings):
        conditioned_features = [
            prompter(rearrange(feature, "b h w c -> b (h w) c"), rays_embeddings)
            for prompter, feature in zip(self.prompt_camera, feat)
        ]
        return conditioned_features

    def process(self, features_list, rays_embeddings):
        conditioned_features = self.condition(features_list, rays_embeddings)
        init_latents = self.to_latents(conditioned_features[0])
        init_latents = rearrange(
            init_latents, "b (h w) c -> b c h w", h=self.shapes[0], w=self.shapes[1]
        ).contiguous()
        conditioned_features = [
            rearrange(
                x, "b (h w) c -> b c h w", h=self.shapes[0], w=self.shapes[1]
            ).contiguous()
            for x in conditioned_features
        ]
        latents = init_latents
        out_features = []
        for i, up in enumerate(self.ups):
            latents = latents + self.process_features[i](conditioned_features[i + 1])
            latents = up(latents)
            out_features.append(latents)
        return out_features, init_latents

    def depth_proj(self, out_features):
        depths = []
        h_out, w_out = out_features[-1].shape[-2:]
        # Aggregate the decoder stages and project to depth. Only the last
        # (finest) stage has a non-identity depth_mlp; earlier stages are
        # skipped via the `continue`.
        for i, (layer, features) in enumerate(zip(self.depth_mlp, out_features)):
            out_depth_features = layer(features.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
            if i < len(self.depth_mlp) - 1:
                continue
            depths.append(out_depth_features)
        out_depth_features = F.interpolate(
            out_depth_features, size=(h_out, w_out), mode="bilinear", align_corners=True
        )
        logdepth = self.to_depth_lr(out_depth_features)
        logdepth = F.interpolate(
            logdepth, size=self.original_shapes, mode="bilinear", align_corners=True
        )
        logdepth = self.to_depth_hr(logdepth)
        return logdepth

    def confidence_proj(self, out_features):
        highres_features = out_features[-1].permute(0, 2, 3, 1)
        confidence = self.confidence_mlp(highres_features).permute(0, 3, 1, 2)
        confidence = self.to_confidence_lr(confidence)
        confidence = F.interpolate(
            confidence, size=self.original_shapes, mode="bilinear", align_corners=True
        )
        confidence = self.to_confidence_hr(confidence)
        return confidence

    def decode(self, out_features):
        logdepth = self.depth_proj(out_features)
        confidence = self.confidence_proj(out_features)
        return logdepth, confidence

    def forward(
        self,
        features: list[torch.Tensor],
        rays_hr: torch.Tensor,
        pos_embed,
        level_embed,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        rays_embeddings = self.embed_rays(rays_hr)
        features, lowres_features = self.process(features, rays_embeddings)
        logdepth, logconf = self.decode(features)
        return logdepth, logconf, lowres_features
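
# Data-flow sketch for RadialModule.forward (illustrative summary):
#   1. embed_rays: camera rays -> per-pixel angular Fourier embedding
#   2. condition:  cross-attend each feature level to the ray embedding
#   3. process:    fuse the levels and upsample through the ResUpsampleBil stack
#   4. decode:     project the finest features to log-depth and log-confidence
# It returns (logdepth, logconfidence, lowres_features).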


class Decoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.build(config)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def run_camera(self, cls_tokens, original_shapes, rays_gt):
        H, W = original_shapes

        # camera layer
        intrinsics, sh_coeffs = self.angular_module(cls_tokens=cls_tokens)

        B, N = intrinsics.shape
        device = intrinsics.device
        dtype = intrinsics.dtype
        id_coords = coords_grid(B, H, W, device=sh_coeffs.device)

        # This is fov based
        longitude = (
            (id_coords[:, 0] - intrinsics[:, 2].view(-1, 1, 1))
            / W
            * intrinsics[:, 0].view(-1, 1, 1)
        )
        latitude = (
            (id_coords[:, 1] - intrinsics[:, 3].view(-1, 1, 1))
            / H
            * intrinsics[:, 1].view(-1, 1, 1)
        )
        x = torch.cos(latitude) * torch.sin(longitude)
        z = torch.cos(latitude) * torch.cos(longitude)
        y = -torch.sin(latitude)
        unit_sphere = torch.stack([x, y, z], dim=-1)
        unit_sphere = unit_sphere / torch.norm(unit_sphere, dim=-1, keepdim=True).clip(
            min=1e-5
        )

        harmonics = rsh_cart_3(unit_sphere)[..., 1:]  # remove constant-value harmonic
        rays_pred = torch.einsum("bhwc,bcd->bhwd", harmonics, sh_coeffs)
        rays_pred = rays_pred / torch.norm(rays_pred, dim=-1, keepdim=True).clip(
            min=1e-5
        )
        rays_pred = rays_pred.permute(0, 3, 1, 2)

        ### LEGACY CODE for training
        # if self.training:
        #     prob = 1 - tanh(self.steps / self.num_steps)
        #     where_use_gt_rays = torch.rand(B, 1, 1, device=device, dtype=dtype) < prob
        #     where_use_gt_rays = where_use_gt_rays.int()
        #     rays = rays_gt * where_use_gt_rays + rays_pred * (1 - where_use_gt_rays)
        # should clean also nans
        if self.training:
            rays = rays_pred
        else:
            rays = rays_gt if rays_gt is not None else rays_pred

        rays = rearrange(rays, "b c h w -> b (h w) c")
        return intrinsics, rays
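
    # Sketch of the ray construction above (shapes assumed for illustration):
    #   longitude/latitude: (B, H, W) angles derived from the predicted
    #                       [hfov, vfov, cx, cy] intrinsics
    #   unit_sphere:        (B, H, W, 3) directions on the unit sphere
    #   harmonics:          (B, H, W, 15) real spherical harmonics up to degree 3,
    #                       constant term dropped
    #   rays_pred:          einsum("bhwc,bcd->bhwd", harmonics, sh_coeffs) -> (B, H, W, 3)
    # At inference time, ground-truth rays (if provided) replace the prediction.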

    def forward(self, inputs, image_metas) -> dict[str, torch.Tensor]:
        B, C, H, W = inputs["image"].shape
        device = inputs["image"].device
        rays_gt = inputs.get("rays", None)

        # get features in b n d format
        common_shape = inputs["features"][0].shape[1:3]

        # project each encoder level to the decoder hidden size
        features = self.input_adapter(inputs["features"])

        # positional embeddings, spatial and level
        level_embed = self.level_embeds.repeat(
            B, common_shape[0] * common_shape[1], 1, 1
        )
        level_embed = rearrange(level_embed, "b n l d -> b (n l) d")
        dummy_tensor = torch.zeros(
            B, 1, common_shape[0], common_shape[1], device=device, requires_grad=False
        )
        pos_embed = self.pos_embed(dummy_tensor)
        pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat(1, 4, 1)

        # get cls tokens projections
        camera_tokens = inputs["tokens"]
        camera_tokens = [self.choker(x.contiguous()) for x in camera_tokens]
        camera_tokens = self.camera_token_adapter(camera_tokens)
        self.angular_module.set_shapes((H, W))
        intrinsics, rays = self.run_camera(
            torch.cat(camera_tokens, dim=1),
            original_shapes=(H, W),
            rays_gt=rays_gt,
        )

        # run bulk of the model
        self.radial_module.set_shapes(common_shape)
        self.radial_module.set_original_shapes((H, W))
        logradius, logconfidence, lowres_features = self.radial_module(
            features=features,
            rays_hr=rays,
            pos_embed=pos_embed,
            level_embed=level_embed,
        )
        radius = torch.exp(logradius.clip(min=-8.0, max=8.0) + 2.0)
        confidence = torch.exp(logconfidence.clip(min=-8.0, max=10.0))

        outputs = {
            "distance": radius,
            "lowres_features": lowres_features,
            "confidence": confidence,
            "K": intrinsics,
            "rays": rays,
        }
        return outputs
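
    # Illustrative output sketch (sizes assumed): for an input image batch of
    # shape (B, 3, H, W), the returned dict holds
    #   "distance":        (B, 1, H, W)  radial distance, exp of the clipped log-radius
    #   "confidence":      (B, 1, H, W)  exp of the clipped log-confidence
    #   "K":               (B, 4)        [hfov, vfov, cx, cy]
    #   "rays":            (B, H*W, 3)   per-pixel unit ray directions
    #   "lowres_features": decoder latents at the backbone feature resolution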

    def no_weight_decay_keywords(self):
        return {"latents_pos", "level_embeds"}

    def get_params(self, lr, wd):
        angles_p, _ = get_params(self.angular_module, lr, wd)
        radius_p, _ = get_params(self.radial_module, lr, wd)
        tokens_p, _ = get_params(self.camera_token_adapter, lr, wd)
        input_p, _ = get_params(self.input_adapter, lr, wd)
        return [*tokens_p, *angles_p, *input_p, *radius_p]

    def build(self, config):
        input_dims = config["model"]["pixel_encoder"]["embed_dims"]
        hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"]
        expansion = config["model"]["expansion"]
        num_heads = config["model"]["num_heads"]
        dropout = config["model"]["pixel_decoder"]["dropout"]
        layer_scale = config["model"]["layer_scale"]
        depth = config["model"]["pixel_decoder"]["depths"]
        depths_encoder = config["model"]["pixel_encoder"]["depths"]
        out_dim = config["model"]["pixel_decoder"]["out_dim"]
        kernel_size = config["model"]["pixel_decoder"]["kernel_size"]

        self.slices_encoder = list(zip([d - 1 for d in depths_encoder], depths_encoder))
        input_dims = [input_dims[d - 1] for d in depths_encoder]

        self.steps = 0
        self.num_steps = config["model"].get("num_steps", 100000)

        camera_dims = input_dims
        self.choker = GradChoker(config["model"]["pixel_decoder"]["detach"])
        self.input_adapter = ListAdapter(input_dims, hidden_dim)
        self.camera_token_adapter = ListAdapter(camera_dims, hidden_dim)
        self.angular_module = AngularModule(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            expansion=expansion,
            dropout=dropout,
            layer_scale=layer_scale,
        )
        self.radial_module = RadialModule(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            expansion=expansion,
            depths=depth,
            dropout=dropout,
            camera_dim=96,
            layer_scale=layer_scale,
            out_dim=out_dim,
            kernel_size=kernel_size,
            num_prompt_blocks=config["model"]["pixel_decoder"]["num_prompt_blocks"],
            use_norm=False,
        )
        self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
        self.level_embeds = nn.Parameter(
            orthonormal_init(len(input_dims), hidden_dim).reshape(
                1, 1, len(input_dims), hidden_dim
            ),
            requires_grad=False,
        )
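

# Configuration sketch for Decoder.build (only the keys read above; all values
# are placeholders, not the released UniK3D configuration):
#
#   config = {
#       "model": {
#           "expansion": ...,
#           "num_heads": ...,
#           "layer_scale": ...,
#           "pixel_encoder": {"embed_dims": [...], "depths": [...]},
#           "pixel_decoder": {
#               "hidden_dim": ...,
#               "dropout": ...,
#               "depths": [...],
#               "out_dim": ...,
#               "kernel_size": ...,
#               "num_prompt_blocks": ...,
#               "detach": ...,
#           },
#           # optional: "num_steps" (defaults to 100000)
#       },
#   }
#   decoder = Decoder(config)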