| """ | |
| Author: Luigi Piccinelli | |
| Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| from math import ceil | |
| import huggingface_hub | |
| import torch.nn.functional as F | |
| import torch.onnx | |
| from unik3d.models.unik3d import UniK3D | |
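# ONNX-friendly wrapper: forward() takes plain tensors (no dict packing) and
# returns only the exported outputs, a (B, 3, H, W) point map and a confidence map.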
class UniK3DONNX(UniK3D):
    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__(config, eps)

    def forward(self, rgbs):
        B, _, H, W = rgbs.shape
        features, tokens = self.pixel_encoder(rgbs)

        inputs = {}
        inputs["image"] = rgbs
        inputs["features"] = [
            self.stacking_fn(features[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        inputs["tokens"] = [
            self.stacking_fn(tokens[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]

        outputs = self.pixel_decoder(inputs, [])
        outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W)
        pts_3d = outputs["rays"] * outputs["radius"]
        return pts_3d, outputs["confidence"]

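# Variant that also consumes the ground-truth camera as per-pixel unprojected
# rays, matching the --with-camera export path below.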
class UniK3DONNXcam(UniK3D):
    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__(config, eps)

    def forward(self, rgbs, rays):
        B, _, H, W = rgbs.shape
        features, tokens = self.pixel_encoder(rgbs)

        inputs = {}
        inputs["image"] = rgbs
        inputs["rays"] = rays
        inputs["features"] = [
            self.stacking_fn(features[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        inputs["tokens"] = [
            self.stacking_fn(tokens[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]

        outputs = self.pixel_decoder(inputs, [])
        outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W)
        pts_3d = outputs["rays"] * outputs["radius"]
        return pts_3d, outputs["confidence"]

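# Trace the wrapper on a dummy input and write the ONNX graph. Only the batch
# dimension is marked dynamic; the spatial resolution is frozen to `shape`.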
def export(model, path, shape=(462, 630), with_camera=False):
    model.eval()

    image = torch.rand(1, 3, *shape)
    dynamic_axes_in = {"rgbs": {0: "batch"}}
    inputs = [image]
    if with_camera:
        rays = torch.rand(1, 3, *shape)
        inputs.append(rays)
        dynamic_axes_in["rays"] = {0: "batch"}

    dynamic_axes_out = {
        "pts_3d": {0: "batch"},
        "confidence": {0: "batch"},
    }

    torch.onnx.export(
        model,
        tuple(inputs),
        path,
        input_names=list(dynamic_axes_in.keys()),
        output_names=list(dynamic_axes_out.keys()),
        opset_version=14,
        dynamic_axes={**dynamic_axes_in, **dynamic_axes_out},
    )
    print(f"Model exported to {path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Export UniK3D model to ONNX")
    parser.add_argument(
        "--backbone",
        type=str,
        default="vitl",
        choices=["vits", "vitb", "vitl"],
        help="Backbone model",
    )
    parser.add_argument(
        "--shape",
        type=int,
        nargs=2,
        default=(462, 630),
        help="Input shape. No dynamic shape supported!",
    )
    parser.add_argument(
        "--output-path", type=str, default="unik3d.onnx", help="Output ONNX file"
    )
    parser.add_argument(
        "--with-camera",
        action="store_true",
        help="Export model that expects the GT camera as unprojected rays at inference",
    )
    args = parser.parse_args()

    backbone = args.backbone
    shape = args.shape
    output_path = args.output_path
    with_camera = args.with_camera
    # force shape to be a multiple of 14 (the ViT patch size)
    shape_rounded = [14 * ceil(x // 14 - 0.5) for x in shape]
    if list(shape) != list(shape_rounded):
        print(f"Shape {shape} is not a multiple of 14. Rounding to {shape_rounded}")
        shape = shape_rounded

    # assumes the command is run from the root of the repo
    with open(os.path.join("configs", f"config_{backbone}.json")) as f:
        config = json.load(f)

    # tell DINO not to use efficient attention: it is not exportable
    config["training"]["export"] = True
    model = UniK3DONNX(config) if not with_camera else UniK3DONNXcam(config)

    path = huggingface_hub.hf_hub_download(
        repo_id=f"lpiccinelli/unik3d-{backbone}",
        filename="pytorch_model.bin",
        repo_type="model",
    )
    info = model.load_state_dict(torch.load(path), strict=False)
    print(f"UniK3D_{backbone} is loaded with:")
    print(f"\t missing keys: {info.missing_keys}")
    print(f"\t additional keys: {info.unexpected_keys}")
    export(
        model=model,
        path=os.path.join(os.environ.get("TMPDIR", "."), output_path),
        shape=shape,
        with_camera=with_camera,
    )
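    # Optional sanity check (a minimal sketch, not part of the original export
    # flow): it assumes `onnxruntime` and `numpy` are installed and simply runs
    # the exported graph once on random inputs to confirm that it loads and that
    # the output shapes look sane.
    try:
        import numpy as np
        import onnxruntime as ort

        onnx_path = os.path.join(os.environ.get("TMPDIR", "."), output_path)
        session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
        feed = {"rgbs": np.random.rand(1, 3, *shape).astype(np.float32)}
        if with_camera:
            feed["rays"] = np.random.rand(1, 3, *shape).astype(np.float32)
        pts_3d, confidence = session.run(None, feed)
        print(f"ONNX check: pts_3d {pts_3d.shape}, confidence {confidence.shape}")
    except ImportError:
        print("onnxruntime not installed, skipping the post-export check")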