MohamedRashad committed
Commit 81535ba · 1 Parent(s): 900119b

Refactor installation commands and improve helper functions for mesh extraction and rigging pipeline

Files changed (1): app.py (+323 −412)
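
The version-matching install pattern this commit keeps relying on resolves GPU wheel names from the runtime torch build. A minimal self-contained sketch of that pattern (the check=True and the CPU fallback are our assumptions, not lines from the commit):

import subprocess
import torch

# e.g. "2.4.0" from "2.4.0+cu121"
torch_version = torch.__version__.split("+")[0]
# e.g. "cu121"; fall back to CPU wheels when no CUDA runtime is present (assumption)
cuda_version = f"cu{torch.version.cuda.replace('.', '')}" if torch.version.cuda else "cpu"

# PyG-style wheel indexes are keyed by torch-<version>+<cuda>
wheel_index = f"https://data.pyg.org/whl/torch-{torch_version}+{cuda_version}.html"
subprocess.run(
    f"pip install torch_scatter torch_cluster -f {wheel_index} --no-cache-dir",
    shell=True, check=True,
)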
app.py CHANGED
@@ -1,7 +1,6 @@
 import shutil
 import subprocess
 import time
- import traceback
 from pathlib import Path
 from typing import Tuple
 
@@ -11,12 +10,7 @@ import spaces
 import torch
 import yaml
 
- subprocess.run([
-     "pip", "install",
-     "flash-attn",
-     "--no-build-isolation",
-     "--find-links", "https://github.com/Dao-AILab/flash-attention/releases"
- ], check=True)
+ subprocess.run('pip install flash-attn --no-build-isolation', shell=True)
 
 # Get the PyTorch and CUDA versions
 torch_version = torch.__version__.split("+")[0]  # Strips any "+cuXXX" suffix
@@ -34,426 +28,343 @@ subprocess.run(f'pip install spconv{spconv_version}', shell=True)
 subprocess.run(f'pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-{torch_version}+{cuda_version}.html --no-cache-dir', shell=True)
 
 
- class UniRigDemo:
-     """Main class for the UniRig Gradio demo application."""
-
-     def __init__(self):
-         # Create temp directory in current directory instead of system temp
-         base_dir = Path(__file__).parent
-         self.temp_dir = base_dir / "tmp"
-         self.temp_dir.mkdir(exist_ok=True)
-
-         # Supported file formats
-         self.supported_formats = ['.obj', '.fbx', '.glb']
-
-     def validate_input_file(self, file_path: str) -> bool:
-         """Validate if the input file format is supported."""
-         if not file_path or not Path(file_path).exists():
-             return False
-
-         file_ext = Path(file_path).suffix.lower()
-         return file_ext in self.supported_formats
-
-     def generate_skeleton(self, input_file: str, seed: int = 12345) -> Tuple[str, str, str]:
-         """
-         OPERATION 1: Generate skeleton for the input 3D model using Python
-
-         Args:
-             input_file: Path to the input 3D model file
-             seed: Random seed for reproducible results
-
-         Returns:
-             Tuple of (status_message, output_file_path, preview_info)
-         """
-         # Validate input
-         if not self.validate_input_file(input_file):
-             return "Error: Invalid or unsupported file format. Supported: " + ", ".join(self.supported_formats), "", ""
-
-         # Create working directory
-         file_stem = Path(input_file).stem
-         input_model_dir = self.temp_dir / f"{file_stem}_{seed}"
-         input_model_dir.mkdir(exist_ok=True)
-
-         # Copy input file to working directory
-         input_file = Path(input_file)
-         shutil.copy2(input_file, input_model_dir / input_file.name)
-         input_file = input_model_dir / input_file.name
-         print(f"New input file path: {input_file}")
-
-         # Generate skeleton using Python (replaces bash script)
-         output_file = input_model_dir / f"{file_stem}_skeleton.fbx"
-
-         self.run_skeleton_inference_python(input_file, output_file, seed)
-
-         if not output_file.exists():
-             return "Error: Skeleton file was not generated", "", ""
-
-         print(f"Generated skeleton at: {output_file}")
-         return str(output_file)
-
-     def merge_results(self, original_file: str, rigged_file: str, output_file) -> str:
-         """
-         OPERATION 3: Merge the rigged skeleton/skin with the original model using Python functions.
-
-         Args:
-             original_file: Path to the original 3D model
-             rigged_file: Path to the rigged file (skeleton or skin)
-
-         Returns:
-             Tuple of (status_message, output_file_path, preview_info)
-         """
-         if not original_file or not Path(original_file).exists():
-             return "Error: Original file not provided or doesn't exist", "", ""
-
-         if not rigged_file or not Path(rigged_file).exists():
-             return "Error: Rigged file not provided or doesn't exist", "", ""
-
-         # Create output file
-         work_dir = Path(rigged_file).parent
-         output_file = work_dir / f"{Path(original_file).stem}_rigged.glb"
-
-         # Run merge using Python function
-         try:
-             self.merge_results_python(rigged_file, original_file, str(output_file))
-         except Exception as e:
-             error_msg = f"Error: Merge failed: {str(e)}"
-             traceback.print_exc()
-             return error_msg, "", ""
-
-         # Validate that the output file exists and is a file (not a directory)
-         output_file_abs = output_file.resolve()
-         if not output_file_abs.exists():
-             return "Error: Merged file was not generated", "", ""
-
-         if not output_file_abs.is_file():
-             return f"Error: Output path is not a valid file: {output_file_abs}", "", ""
-
-         # Generate preview information
-         preview_info = self.generate_model_preview(str(output_file_abs))
-
-         return "✅ Model rigging completed successfully!", str(output_file_abs), preview_info
-
-     @spaces.GPU()
-     def complete_pipeline(self, input_file: str, seed: int = 12345) -> Tuple[str, str, str, str, str]:
-         """
-         Run the complete rigging pipeline: skeleton generation → skinning → merge.
-
-         Args:
-             input_file: Path to the input 3D model file
-             seed: Random seed for reproducible results
-
-         Returns:
-             Tuple of status messages and file paths for each step
-         """
-         # Validate input file
-         if not self.validate_input_file(input_file):
-             raise gr.Error(f"Error: Invalid or unsupported file format. Supported formats: {', '.join(self.supported_formats)}")
-
-         # Create working directory
-         file_stem = Path(input_file).stem
-         input_model_dir = self.temp_dir / f"{file_stem}_{seed}"
-         input_model_dir.mkdir(exist_ok=True)
-
-         # Copy input file to working directory
-         input_file = Path(input_file)
-         shutil.copy2(input_file, input_model_dir / input_file.name)
-         input_file = input_model_dir / input_file.name
-         print(f"New input file path: {input_file}")
-
-         # Step 1: Generate skeleton
-         output_skeleton_file = input_model_dir / f"{file_stem}_skeleton.fbx"
-         self.run_skeleton_inference_python(input_file, output_skeleton_file, seed)
-
-         # Step 2: Generate skinning
-         output_skin_file = input_model_dir / f"{file_stem}_skin.fbx"
-         self.run_skin_inference_python(output_skeleton_file, output_skin_file)
-
-         # Step 3: Merge results
-         final_file = input_model_dir / f"{file_stem}_rigged.glb"
-         self.merge_results_python(output_skin_file, input_file, final_file)
-
-         return str(final_file), [str(output_skeleton_file), str(output_skin_file), str(final_file)]
-
-     def extract_mesh_python(self, input_file: str, output_dir: str) -> str:
-         """
-         Extract mesh data from 3D model using Python (replaces extract.sh)
-         Returns path to generated .npz file
-         """
-         # Import required modules
-         from src.data.extract import get_files, extract_builtin
-
-         # Create extraction parameters
-         files = get_files(
-             data_name="raw_data.npz",
-             inputs=str(input_file),
-             input_dataset_dir=None,
-             output_dataset_dir=output_dir,
-             force_override=True,
-             warning=False,
-         )
-
-         if not files:
-             raise RuntimeError("No files to extract")
-
-         # Run the actual extraction
-         timestamp = str(int(time.time()))
-         extract_builtin(
-             output_folder=output_dir,
-             target_count=50000,
-             num_runs=1,
-             id=0,
-             time=timestamp,
-             files=files,
-         )
-
-         # Return the directory path where raw_data.npz was created
-         # The dataset expects to find raw_data.npz in this directory
-         expected_npz_dir = files[0][1]  # This is the output directory
-         expected_npz_file = Path(expected_npz_dir) / "raw_data.npz"
-
-         if not expected_npz_file.exists():
-             raise RuntimeError(f"Extraction failed: {expected_npz_file} not found")
-
-         return expected_npz_dir  # Return the directory containing raw_data.npz
-
-     def run_skeleton_inference_python(self, input_file: str, output_file: str, seed: int = 12345) -> str:
-         """
-         Run skeleton inference using Python (replaces skeleton part of generate_skeleton.sh)
-         Returns path to skeleton FBX file
-         """
-         from box import Box
-
-         from src.data.datapath import Datapath
-         from src.data.dataset import DatasetConfig, UniRigDatasetModule
-         from src.data.transform import TransformConfig
-         from src.inference.download import download
-         from src.model.parse import get_model
-         from src.system.parse import get_system, get_writer
-         from src.tokenizer.parse import get_tokenizer
-         from src.tokenizer.spec import TokenizerConfig
-
-         # Set random seed
-         L.seed_everything(seed, workers=True)
-
-         # Load task configuration
-         task_config_path = "configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml"
-         if not Path(task_config_path).exists():
-             raise FileNotFoundError(f"Task configuration file not found: {task_config_path}")
-
-         # Load the task configuration
-         with open(task_config_path, 'r') as f:
-             task = Box(yaml.safe_load(f))
-
-         # Create temporary npz directory
-         npz_dir = Path(output_file).parent / "npz"
-         npz_dir.mkdir(exist_ok=True)
-
-         # Extract mesh data
-         npz_data_dir = self.extract_mesh_python(input_file, npz_dir)
-
-         # Setup datapath with the directory containing raw_data.npz
-         datapath = Datapath(files=[npz_data_dir], cls=None)
-
-         # Load configurations
-         data_config = Box(yaml.safe_load(open("configs/data/quick_inference.yaml", 'r')))
-         transform_config = Box(yaml.safe_load(open("configs/transform/inference_ar_transform.yaml", 'r')))
-
-         # Get tokenizer
-         tokenizer_config = TokenizerConfig.parse(config=Box(yaml.safe_load(open("configs/tokenizer/tokenizer_parts_articulationxl_256.yaml", 'r'))))
-         tokenizer = get_tokenizer(config=tokenizer_config)
-
-         # Get model
-         model_config = Box(yaml.safe_load(open("configs/model/unirig_ar_350m_1024_81920_float32.yaml", 'r')))
-         model = get_model(tokenizer=tokenizer, **model_config)
-
-         # Setup datasets and transforms
-         predict_dataset_config = DatasetConfig.parse(config=data_config.predict_dataset_config).split_by_cls()
-         predict_transform_config = TransformConfig.parse(config=transform_config.predict_transform_config)
-
-         # Create data module
-         data = UniRigDatasetModule(
-             process_fn=model._process_fn,
-             predict_dataset_config=predict_dataset_config,
-             predict_transform_config=predict_transform_config,
-             tokenizer_config=tokenizer_config,
-             debug=False,
-             data_name="raw_data.npz",
-             datapath=datapath,
-             cls=None,
-         )
-
-         # Setup callbacks and writer
-         callbacks = []
-         writer_config = task.writer.copy()
-         writer_config['npz_dir'] = str(npz_dir)
-         writer_config['output_dir'] = str(Path(output_file).parent)
-         writer_config['output_name'] = Path(output_file).name
-         writer_config['user_mode'] = False  # Set to False to enable NPZ export
-         print(f"Writer config: {writer_config}")
-         # But we want the FBX to go to our specified location when in user mode for FBX
-         callbacks.append(get_writer(**writer_config, order_config=predict_transform_config.order_config))
-
-         # Get system
-         system_config = Box(yaml.safe_load(open("configs/system/ar_inference_articulationxl.yaml", 'r')))
-         system = get_system(**system_config, model=model, steps_per_epoch=1)
-
-         # Setup trainer
-         trainer_config = task.trainer
-         resume_from_checkpoint = download(task.resume_from_checkpoint)
-
-         trainer = L.Trainer(callbacks=callbacks, logger=None, **trainer_config)
-
-         # Run prediction
-         trainer.predict(system, datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False)
-
-         # The actual output file will be in a subdirectory named after the input file
-         # Look for the generated skeleton.fbx file
-         input_name_stem = Path(input_file).stem
-         actual_output_dir = Path(output_file).parent / input_name_stem
-         actual_output_file = actual_output_dir / "skeleton.fbx"
-
-         if not actual_output_file.exists():
-             # Try alternative locations - look for any skeleton.fbx file in the output directory
-             alt_files = list(Path(output_file).parent.rglob("skeleton.fbx"))
-             if alt_files:
-                 actual_output_file = alt_files[0]
-                 print(f"Found skeleton at alternative location: {actual_output_file}")
-             else:
-                 # List all files for debugging
-                 all_files = list(Path(output_file).parent.rglob("*"))
-                 print(f"Available files: {[str(f) for f in all_files]}")
-                 raise RuntimeError(f"Skeleton FBX file not found. Expected at: {actual_output_file}")
-
-         # Copy to the expected output location
-         if actual_output_file != Path(output_file):
-             shutil.copy2(actual_output_file, output_file)
-             print(f"Copied skeleton from {actual_output_file} to {output_file}")
-
-         print(f"Generated skeleton at: {output_file}")
-         return str(output_file)
-
-     def run_skin_inference_python(self, skeleton_file: str, output_file: str) -> str:
-         """
-         Run skin inference using Python (replaces skin part of generate_skin.sh)
-         Returns path to skin FBX file
-         """
-         from box import Box
-
-         from src.data.datapath import Datapath
-         from src.data.dataset import DatasetConfig, UniRigDatasetModule
-         from src.data.transform import TransformConfig
-         from src.inference.download import download
-         from src.model.parse import get_model
-         from src.system.parse import get_system, get_writer
-
-         # Load task configuration
-         task_config_path = "configs/task/quick_inference_unirig_skin.yaml"
-         with open(task_config_path, 'r') as f:
-             task = Box(yaml.safe_load(f))
-
-         # Look for files matching predict_skeleton.npz pattern recursively
-         skeleton_work_dir = Path(skeleton_file).parent
-         all_npz_files = list(skeleton_work_dir.rglob("**/*.npz"))
-
-         # Setup datapath - need to pass the directory containing the NPZ file
-         skeleton_npz_dir = all_npz_files[0].parent
-         datapath = Datapath(files=[str(skeleton_npz_dir)], cls=None)
-
-         # Load configurations
-         data_config = Box(yaml.safe_load(open("configs/data/quick_inference.yaml", 'r')))
-         transform_config = Box(yaml.safe_load(open("configs/transform/inference_skin_transform.yaml", 'r')))
-
-         # Get model
-         model_config = Box(yaml.safe_load(open("configs/model/unirig_skin.yaml", 'r')))
-         model = get_model(tokenizer=None, **model_config)
-
-         # Setup datasets and transforms
-         predict_dataset_config = DatasetConfig.parse(config=data_config.predict_dataset_config).split_by_cls()
-         predict_transform_config = TransformConfig.parse(config=transform_config.predict_transform_config)
-
-         # Create data module
-         data = UniRigDatasetModule(
-             process_fn=model._process_fn,
-             predict_dataset_config=predict_dataset_config,
-             predict_transform_config=predict_transform_config,
-             tokenizer_config=None,
-             debug=False,
-             data_name="predict_skeleton.npz",
-             datapath=datapath,
-             cls=None,
-         )
-
-         # Setup callbacks and writer
-         callbacks = []
-         writer_config = task.writer.copy()
-         writer_config['npz_dir'] = str(skeleton_npz_dir)
-         writer_config['output_name'] = str(output_file)
-         writer_config['user_mode'] = True
-         writer_config['export_fbx'] = True  # Enable FBX export
-         callbacks.append(get_writer(**writer_config, order_config=predict_transform_config.order_config))
-
-         # Get system
-         system_config = Box(yaml.safe_load(open("configs/system/skin.yaml", 'r')))
-         system = get_system(**system_config, model=model, steps_per_epoch=1)
-
-         # Setup trainer
-         trainer_config = task.trainer
-         resume_from_checkpoint = download(task.resume_from_checkpoint)
-
-         trainer = L.Trainer(callbacks=callbacks, logger=None, **trainer_config)
-
-         # Run prediction
-         trainer.predict(system, datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False)
-
-         # The skin FBX file should be generated with the specified output name
-         # Since user_mode is True and export_fbx is True, it should create the file directly
-         if not Path(output_file).exists():
-             # Look for generated skin FBX files in the output directory
-             skin_files = list(Path(output_file).parent.rglob("*skin*.fbx"))
-             if skin_files:
-                 actual_output_file = skin_files[0]
-                 # Copy/move to the expected location
-                 shutil.copy2(actual_output_file, output_file)
-             else:
-                 raise RuntimeError(f"Skin FBX file not found. Expected at: {output_file}")
-
-         return str(output_file)
-
-     def merge_results_python(self, source_file: str, target_file: str, output_file: str) -> str:
-         """
-         Merge results using Python (replaces merge.sh)
-         Returns path to merged file
-         """
-         from src.inference.merge import transfer
-
-         # Validate input paths
-         if not Path(source_file).exists():
-             raise ValueError(f"Source file does not exist: {source_file}")
-         if not Path(target_file).exists():
-             raise ValueError(f"Target file does not exist: {target_file}")
-
-         # Ensure output directory exists
-         output_path = Path(output_file)
-         output_path.parent.mkdir(parents=True, exist_ok=True)
-
-         # Use the transfer function directly
-         transfer(source=str(source_file), target=str(target_file), output=str(output_path), add_root=False)
-
-         # Validate that the output file was created and is a valid file
-         if not output_path.exists():
-             raise RuntimeError(f"Merge failed: Output file not created at {output_path}")
-
-         if not output_path.is_file():
-             raise RuntimeError(f"Merge failed: Output path is not a valid file: {output_path}")
-
-         return str(output_path.resolve())
+ # Helper functions
+ def validate_input_file(file_path: str) -> bool:
+     """Validate if the input file format is supported."""
+     supported_formats = ['.obj', '.fbx', '.glb']
+     if not file_path or not Path(file_path).exists():
+         return False
+
+     file_ext = Path(file_path).suffix.lower()
+     return file_ext in supported_formats
+
+ def extract_mesh_python(input_file: str, output_dir: str) -> str:
+     """
+     Extract mesh data from 3D model using Python (replaces extract.sh)
+     Returns path to generated .npz file
+     """
+     # Import required modules
+     from src.data.extract import get_files, extract_builtin
+
+     # Create extraction parameters
+     files = get_files(
+         data_name="raw_data.npz",
+         inputs=str(input_file),
+         input_dataset_dir=None,
+         output_dataset_dir=output_dir,
+         force_override=True,
+         warning=False,
+     )
+
+     if not files:
+         raise RuntimeError("No files to extract")
+
+     # Run the actual extraction
+     timestamp = str(int(time.time()))
+     extract_builtin(
+         output_folder=output_dir,
+         target_count=50000,
+         num_runs=1,
+         id=0,
+         time=timestamp,
+         files=files,
+     )
+
+     # Return the directory path where raw_data.npz was created
+     # The dataset expects to find raw_data.npz in this directory
+     expected_npz_dir = files[0][1]  # This is the output directory
+     expected_npz_file = Path(expected_npz_dir) / "raw_data.npz"
+
+     if not expected_npz_file.exists():
+         raise RuntimeError(f"Extraction failed: {expected_npz_file} not found")
+
+     return expected_npz_dir  # Return the directory containing raw_data.npz
+
+ def run_skeleton_inference_python(input_file: str, output_file: str, seed: int = 12345) -> str:
+     """
+     Run skeleton inference using Python (replaces skeleton part of generate_skeleton.sh)
+     Returns path to skeleton FBX file
+     """
+     from box import Box
+
+     from src.data.datapath import Datapath
+     from src.data.dataset import DatasetConfig, UniRigDatasetModule
+     from src.data.transform import TransformConfig
+     from src.inference.download import download
+     from src.model.parse import get_model
+     from src.system.parse import get_system, get_writer
+     from src.tokenizer.parse import get_tokenizer
+     from src.tokenizer.spec import TokenizerConfig
+
+     # Set random seed
+     L.seed_everything(seed, workers=True)
+
+     # Load task configuration
+     task_config_path = "configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml"
+     if not Path(task_config_path).exists():
+         raise FileNotFoundError(f"Task configuration file not found: {task_config_path}")
+
+     # Load the task configuration
+     with open(task_config_path, 'r') as f:
+         task = Box(yaml.safe_load(f))
+
+     # Create temporary npz directory
+     npz_dir = Path(output_file).parent / "npz"
+     npz_dir.mkdir(exist_ok=True)
+
+     # Extract mesh data
+     npz_data_dir = extract_mesh_python(input_file, npz_dir)
+
+     # Setup datapath with the directory containing raw_data.npz
+     datapath = Datapath(files=[npz_data_dir], cls=None)
+
+     # Load configurations
+     data_config = Box(yaml.safe_load(open("configs/data/quick_inference.yaml", 'r')))
+     transform_config = Box(yaml.safe_load(open("configs/transform/inference_ar_transform.yaml", 'r')))
+
+     # Get tokenizer
+     tokenizer_config = TokenizerConfig.parse(config=Box(yaml.safe_load(open("configs/tokenizer/tokenizer_parts_articulationxl_256.yaml", 'r'))))
+     tokenizer = get_tokenizer(config=tokenizer_config)
+
+     # Get model
+     model_config = Box(yaml.safe_load(open("configs/model/unirig_ar_350m_1024_81920_float32.yaml", 'r')))
+     model = get_model(tokenizer=tokenizer, **model_config)
+
+     # Setup datasets and transforms
+     predict_dataset_config = DatasetConfig.parse(config=data_config.predict_dataset_config).split_by_cls()
+     predict_transform_config = TransformConfig.parse(config=transform_config.predict_transform_config)
+
+     # Create data module
+     data = UniRigDatasetModule(
+         process_fn=model._process_fn,
+         predict_dataset_config=predict_dataset_config,
+         predict_transform_config=predict_transform_config,
+         tokenizer_config=tokenizer_config,
+         debug=False,
+         data_name="raw_data.npz",
+         datapath=datapath,
+         cls=None,
+     )
+
+     # Setup callbacks and writer
+     callbacks = []
+     writer_config = task.writer.copy()
+     writer_config['npz_dir'] = str(npz_dir)
+     writer_config['output_dir'] = str(Path(output_file).parent)
+     writer_config['output_name'] = Path(output_file).name
+     writer_config['user_mode'] = False  # Set to False to enable NPZ export
+     print(f"Writer config: {writer_config}")
+     # But we want the FBX to go to our specified location when in user mode for FBX
+     callbacks.append(get_writer(**writer_config, order_config=predict_transform_config.order_config))
+
+     # Get system
+     system_config = Box(yaml.safe_load(open("configs/system/ar_inference_articulationxl.yaml", 'r')))
+     system = get_system(**system_config, model=model, steps_per_epoch=1)
+
+     # Setup trainer
+     trainer_config = task.trainer
+     resume_from_checkpoint = download(task.resume_from_checkpoint)
+
+     trainer = L.Trainer(callbacks=callbacks, logger=None, **trainer_config)
+
+     # Run prediction
+     trainer.predict(system, datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False)
+
+     # The actual output file will be in a subdirectory named after the input file
+     # Look for the generated skeleton.fbx file
+     input_name_stem = Path(input_file).stem
+     actual_output_dir = Path(output_file).parent / input_name_stem
+     actual_output_file = actual_output_dir / "skeleton.fbx"
+
+     if not actual_output_file.exists():
+         # Try alternative locations - look for any skeleton.fbx file in the output directory
+         alt_files = list(Path(output_file).parent.rglob("skeleton.fbx"))
+         if alt_files:
+             actual_output_file = alt_files[0]
+             print(f"Found skeleton at alternative location: {actual_output_file}")
+         else:
+             # List all files for debugging
+             all_files = list(Path(output_file).parent.rglob("*"))
+             print(f"Available files: {[str(f) for f in all_files]}")
+             raise RuntimeError(f"Skeleton FBX file not found. Expected at: {actual_output_file}")
+
+     # Copy to the expected output location
+     if actual_output_file != Path(output_file):
+         shutil.copy2(actual_output_file, output_file)
+         print(f"Copied skeleton from {actual_output_file} to {output_file}")
+
+     print(f"Generated skeleton at: {output_file}")
+     return str(output_file)
+
+ def run_skin_inference_python(skeleton_file: str, output_file: str) -> str:
+     """
+     Run skin inference using Python (replaces skin part of generate_skin.sh)
+     Returns path to skin FBX file
+     """
+     from box import Box
+
+     from src.data.datapath import Datapath
+     from src.data.dataset import DatasetConfig, UniRigDatasetModule
+     from src.data.transform import TransformConfig
+     from src.inference.download import download
+     from src.model.parse import get_model
+     from src.system.parse import get_system, get_writer
+
+     # Load task configuration
+     task_config_path = "configs/task/quick_inference_unirig_skin.yaml"
+     with open(task_config_path, 'r') as f:
+         task = Box(yaml.safe_load(f))
+
+     # Look for files matching predict_skeleton.npz pattern recursively
+     skeleton_work_dir = Path(skeleton_file).parent
+     all_npz_files = list(skeleton_work_dir.rglob("**/*.npz"))
+
+     # Setup datapath - need to pass the directory containing the NPZ file
+     skeleton_npz_dir = all_npz_files[0].parent
+     datapath = Datapath(files=[str(skeleton_npz_dir)], cls=None)
+
+     # Load configurations
+     data_config = Box(yaml.safe_load(open("configs/data/quick_inference.yaml", 'r')))
+     transform_config = Box(yaml.safe_load(open("configs/transform/inference_skin_transform.yaml", 'r')))
+
+     # Get model
+     model_config = Box(yaml.safe_load(open("configs/model/unirig_skin.yaml", 'r')))
+     model = get_model(tokenizer=None, **model_config)
+
+     # Setup datasets and transforms
+     predict_dataset_config = DatasetConfig.parse(config=data_config.predict_dataset_config).split_by_cls()
+     predict_transform_config = TransformConfig.parse(config=transform_config.predict_transform_config)
+
+     # Create data module
+     data = UniRigDatasetModule(
+         process_fn=model._process_fn,
+         predict_dataset_config=predict_dataset_config,
+         predict_transform_config=predict_transform_config,
+         tokenizer_config=None,
+         debug=False,
+         data_name="predict_skeleton.npz",
+         datapath=datapath,
+         cls=None,
+     )
+
+     # Setup callbacks and writer
+     callbacks = []
+     writer_config = task.writer.copy()
+     writer_config['npz_dir'] = str(skeleton_npz_dir)
+     writer_config['output_name'] = str(output_file)
+     writer_config['user_mode'] = True
+     writer_config['export_fbx'] = True  # Enable FBX export
+     callbacks.append(get_writer(**writer_config, order_config=predict_transform_config.order_config))
+
+     # Get system
+     system_config = Box(yaml.safe_load(open("configs/system/skin.yaml", 'r')))
+     system = get_system(**system_config, model=model, steps_per_epoch=1)
+
+     # Setup trainer
+     trainer_config = task.trainer
+     resume_from_checkpoint = download(task.resume_from_checkpoint)
+
+     trainer = L.Trainer(callbacks=callbacks, logger=None, **trainer_config)
+
+     # Run prediction
+     trainer.predict(system, datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False)
+
+     # The skin FBX file should be generated with the specified output name
+     # Since user_mode is True and export_fbx is True, it should create the file directly
+     if not Path(output_file).exists():
+         # Look for generated skin FBX files in the output directory
+         skin_files = list(Path(output_file).parent.rglob("*skin*.fbx"))
+         if skin_files:
+             actual_output_file = skin_files[0]
+             # Copy/move to the expected location
+             shutil.copy2(actual_output_file, output_file)
+         else:
+             raise RuntimeError(f"Skin FBX file not found. Expected at: {output_file}")
+
+     return str(output_file)
+
+ def merge_results_python(source_file: str, target_file: str, output_file: str) -> str:
+     """
+     Merge results using Python (replaces merge.sh)
+     Returns path to merged file
+     """
+     from src.inference.merge import transfer
+
+     # Validate input paths
+     if not Path(source_file).exists():
+         raise ValueError(f"Source file does not exist: {source_file}")
+     if not Path(target_file).exists():
+         raise ValueError(f"Target file does not exist: {target_file}")
+
+     # Ensure output directory exists
+     output_path = Path(output_file)
+     output_path.parent.mkdir(parents=True, exist_ok=True)
+
+     # Use the transfer function directly
+     transfer(source=str(source_file), target=str(target_file), output=str(output_path), add_root=False)
+
+     # Validate that the output file was created and is a valid file
+     if not output_path.exists():
+         raise RuntimeError(f"Merge failed: Output file not created at {output_path}")
+
+     if not output_path.is_file():
+         raise RuntimeError(f"Merge failed: Output path is not a valid file: {output_path}")
+
+     return str(output_path.resolve())
+
+ @spaces.GPU()
+ def complete_pipeline(input_file: str, seed: int = 12345) -> Tuple[str, list]:
+     """
+     Run the complete rigging pipeline: skeleton generation → skinning → merge.
+
+     Args:
+         input_file: Path to the input 3D model file
+         seed: Random seed for reproducible results
+
+     Returns:
+         Tuple of (final_file_path, list_of_intermediate_files)
+     """
+     # Create temp directory
+     base_dir = Path(__file__).parent
+     temp_dir = base_dir / "tmp"
+     temp_dir.mkdir(exist_ok=True)
+
+     # Supported file formats
+     supported_formats = ['.obj', '.fbx', '.glb']
+
+     # Validate input file
+     if not validate_input_file(input_file):
+         raise gr.Error(f"Error: Invalid or unsupported file format. Supported formats: {', '.join(supported_formats)}")
+
+     # Create working directory
+     file_stem = Path(input_file).stem
+     input_model_dir = temp_dir / f"{file_stem}_{seed}"
+     input_model_dir.mkdir(exist_ok=True)
+
+     # Copy input file to working directory
+     input_file = Path(input_file)
+     shutil.copy2(input_file, input_model_dir / input_file.name)
+     input_file = input_model_dir / input_file.name
+     print(f"New input file path: {input_file}")
+
+     # Step 1: Generate skeleton
+     output_skeleton_file = input_model_dir / f"{file_stem}_skeleton.fbx"
+     run_skeleton_inference_python(input_file, output_skeleton_file, seed)
+
+     # Step 2: Generate skinning
+     output_skin_file = input_model_dir / f"{file_stem}_skin.fbx"
+     run_skin_inference_python(output_skeleton_file, output_skin_file)
+
+     # Step 3: Merge results
+     final_file = input_model_dir / f"{file_stem}_rigged.glb"
+     merge_results_python(output_skin_file, input_file, final_file)
+
+     return str(final_file), [str(output_skeleton_file), str(output_skin_file), str(final_file)]
 
 
 def create_app():
     """Create and configure the Gradio interface."""
 
-     demo_instance = UniRigDemo()
-
     with gr.Blocks(title="UniRig - 3D Model Rigging Demo") as interface:
 
         # Header
@@ -502,7 +413,7 @@ def create_app():
         )
 
         pipeline_btn.click(
-            fn=demo_instance.complete_pipeline,
+            fn=complete_pipeline,
             inputs=[input_3d_model, seed],
             outputs=[pipeline_skeleton_out, files_to_download]
         )
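
For reference, a minimal usage sketch of the refactored module-level entry point (the input path here is hypothetical; assumes a UniRig checkout with the configs referenced above and a GPU session for the @spaces.GPU decorator):

# Hypothetical local invocation, not part of the commit:
# complete_pipeline copies the input into ./tmp/<stem>_<seed>/, then runs
# skeleton prediction -> skin prediction -> merge, and returns the final .glb
# together with the intermediate files offered for download.
final_glb, downloads = complete_pipeline("examples/character.glb", seed=12345)
print(final_glb)   # e.g. tmp/character_12345/character_rigged.glb
print(downloads)   # [skeleton .fbx, skin .fbx, rigged .glb]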