Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| from einops import rearrange | |
| from unik3d.utils.camera import CameraSampler | |
| from unik3d.utils.coordinate import coords_grid | |
| from unik3d.utils.geometric import iou | |
| try: | |
| from splatting import splatting_function | |
| except Exception as e: | |
| splatting_function = None | |
| print( | |
| f"Splatting not available, please install it from github.com/hperrot/splatting" | |
| ) | |
def fill(self, rgb, mask):
    """Overwrite the invalid pixels of every image in the batch with filler values.

    For each batch element one filler is drawn at random (through
    ``np.random.choice``) among: unit Gaussian noise, constant ``-2``,
    constant ``2`` and constant ``0``.

    Args:
        rgb: 4-D image batch; it is modified in place. The first two
            dimensions are read as ``(B, C)``.
        mask: validity mask that is tiled ``C`` times along dim 1 and cast to
            bool, i.e. it is expected to be ``(B, 1, H, W)``. Truthy entries
            mark pixels to keep; everything else is overwritten.

    Returns:
        The same ``rgb`` tensor object, after the in-place fill.
    """

    # The fillers are created directly in the image dtype so the masked
    # assignment below never mixes dtypes (e.g. for half/double images).
    def fill_noise(size, device, dtype):
        return torch.normal(0.0, 1.0, size=size, device=device, dtype=dtype)

    def fill_black(size, device, dtype):
        return torch.full(size, -2.0, device=device, dtype=dtype)

    def fill_white(size, device, dtype):
        return torch.full(size, 2.0, device=device, dtype=dtype)

    def fill_zero(size, device, dtype):
        return torch.zeros(size, device=device, dtype=dtype)

    fillers = [fill_noise, fill_black, fill_white, fill_zero]

    B, C = rgb.shape[:2]
    invalid_mask = ~mask.repeat(1, C, 1, 1).bool()
    for i in range(B):
        filler_fn = np.random.choice(fillers)
        invalid_i = invalid_mask[i]
        # Boolean indexing yields a flat selection, so a 1-D filler with one
        # value per invalid element is what the assignment expects.
        num_invalid = int(invalid_i.sum())
        rgb[i][invalid_i] = filler_fn(
            size=(num_invalid,), device=rgb.device, dtype=rgb.dtype
        )
    return rgb
def augment_camera(self, inputs, camera_sampler):
    """Randomly re-render some batch samples under a newly sampled camera.

    A sample is a candidate when its ``valid_camera`` flag is set and its
    ``depth_mask`` is not empty; each candidate is then selected when a
    uniform draw exceeds 0.9. A selected sample is unprojected with its
    current camera, projected into a camera returned by ``camera_sampler``
    and its image is forward-warped with softmax splatting. Samples whose
    warp is judged too extreme are left untouched.

    Args:
        inputs: batch dictionary. Keys read here: ``"image"``, ``"depth"``
            (dtype/device only), ``"depth_guidance"``, ``"validity_mask"``,
            ``"valid_camera"``, ``"depth_mask"`` and ``"camera"``. Keys
            written per augmented sample: ``"depth_mask"``, ``"rays"``,
            ``"camera"``, ``"image"``, ``"validity_mask"`` and
            ``"grid_sample"``.
        camera_sampler: callable invoked as
            ``camera_sampler(fx, fy, cx, cy, mult=1.0, ratio=ratio, H=H)``.
            The returned camera must provide ``project``,
            ``projection_mask``, ``overlap_mask`` and ``get_rays``.

    Returns:
        The same ``inputs`` dictionary, updated in place.

    Raises:
        ImportError: if a sample reaches the warping step but the optional
            ``splatting`` package could not be imported.
    """
    rgb = inputs["image"]
    # from GT if dense/synthetic or from a model's metric output
    guidance = inputs["depth_guidance"]
    validity_mask = inputs["validity_mask"].bool()
    # Only dtype and device of the depth map are needed; no copy is made.
    dtype, device = inputs["depth"].dtype, inputs["depth"].device
    B, C, H, W = rgb.shape

    augmentable_indices = inputs["valid_camera"] & (
        inputs["depth_mask"].reshape(B, -1).float().mean(dim=1) > 0.0
    )
    augment_indices = torch.rand(B, 1, 1, device=device, dtype=dtype) > 0.9
    augment_indices[~augmentable_indices] = False
    id_coords = coords_grid(B, H, W, device=device)

    # Read the selection back once instead of truth-testing a tensor element
    # on every iteration.
    for i, is_augment in enumerate(augment_indices.reshape(-1).tolist()):
        if not is_augment:
            continue

        pinhole_camera = inputs["camera"][i]
        fov = max(pinhole_camera.hfov[0], pinhole_camera.vfov[0]) * 180 / np.pi
        ratio = min(70.0 / fov, 1.0)  # decrease effect for larger fov
        if fov < 40.0:  # narrow FoV is skipped (~5% per the original author)
            continue

        rgb_i = rgb[i : i + 1]
        id_coords_i = id_coords[i : i + 1]
        guidance_i = guidance[i : i + 1]
        valid_i = validity_mask[i : i + 1]

        if (guidance_i < 0.0).any():
            # Negative guidance would turn into NaN under the sqrt below.
            continue

        # NOTE(review): the original author flagged this sqrt as unexplained.
        # The square-rooted depth is what gets reprojected and what weights
        # the splatting; the un-rooted guidance is what gets warped.
        depth = guidance_i.sqrt()
        # Invalid pixels are placed behind everything else.
        depth[~valid_i] = depth.max() * 2.0

        fx, fy, cx, cy = pinhole_camera.params[:, :4].unbind(dim=-1)
        new_camera = camera_sampler(fx, fy, cx, cy, mult=1.0, ratio=ratio, H=H)
        unprojected = pinhole_camera.reconstruct(depth)
        projected = new_camera.project(unprojected)
        projection_mask = new_camera.projection_mask
        overlap_mask = (
            new_camera.overlap_mask
            if new_camera.overlap_mask is not None
            else torch.ones_like(projection_mask)
        )
        mask = valid_i & overlap_mask

        # Author notes (kept, reworded):
        # - when content actually leaves the frame the affected regions have
        #   to be remembered; tangential distortion used to keep the
        #   validity-mask border after re-warping;
        # - the overlap definition masks the wrong parts for vortex-like
        #   warps, and `is_collapse` does not account for them either.
        is_collapse = (projected[0, 1, 0, :] >= 0.0).all()
        if is_collapse:
            # Masked-out pixels keep their identity coordinates (zero flow).
            outside = ~mask.repeat(1, 2, 1, 1)
            projected[outside] = id_coords_i[outside]
        flow = projected - id_coords_i
        depth[~mask] = depth.max() * 2.0

        if flow.norm(dim=1).median() / max(H, W) > 0.1:  # extreme cases
            continue

        if splatting_function is None:
            raise ImportError(
                "Camera augmentation requires the 'splatting' package, "
                "please install it from github.com/hperrot/splatting"
            )

        # Warp image, guidance and mask together via soft splatting; closer
        # points (smaller depth) receive the larger importance.
        depth_image = torch.cat([rgb_i, guidance_i, mask], dim=1)
        depth_image = splatting_function(
            "softmax", depth_image, flow, -torch.log(1 + depth.clip(0.01))
        )
        rgb_warp = depth_image[:, :C]
        warped_valid = depth_image[:, -1:] > 0.0

        expanding = warped_valid.sum() > valid_i.sum()
        threshold = 0.7 if expanding else 0.25
        _iou = iou(warped_valid, valid_i)
        if _iou < threshold:  # too strong augmentation, lose most of the image
            continue

        # where it goes out
        mask_unwarpable = projection_mask & overlap_mask
        inputs["depth_mask"][i] = inputs["depth_mask"][i] & mask_unwarpable.squeeze(0)

        # Compute the rays of the new camera and use them for supervision;
        # any ray containing a NaN component is zeroed.
        rays = new_camera.get_rays(shapes=(1, H, W))
        rays = rearrange(rays, "b c h w -> b (h w) c")
        inputs["rays"][i] = torch.where(
            rays.isnan().any(dim=-1, keepdim=True), 0.0, rays
        )[0]

        # update image, camera and validity_mask
        inputs["camera"][i] = new_camera
        inputs["image"][i] = self.fill(rgb_warp, warped_valid)[0]
        inputs["validity_mask"][i] = inputs["validity_mask"][i] & mask_unwarpable[0]

        # needed to reverse the augmentation for loss-computation
        # (i.e. un-warp the prediction)
        inputs["grid_sample"][i] = projected[0]

    return inputs