MM-MVR committed on
Commit 97bc03d · verified · 1 Parent(s): a503028

Upload files

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/editing.png filter=lfs diff=lfs merge=lfs -text
+ assets/understand.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,13 @@
  ---
  title: STAR
- emoji: 💻
- colorFrom: blue
- colorTo: red
+ emoji: 👁
+ colorFrom: green
+ colorTo: yellow
  sdk: gradio
  sdk_version: 5.49.1
  app_file: app.py
  pinned: false
  license: apache-2.0
- short_description: STAR Demo
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,763 @@
1
+ import os
2
+ import sys
3
+ import spaces
4
+ import gradio as gr
5
+ import numpy as np
6
+ import torch
7
+ import random
8
+ import time
9
+ from PIL import Image
10
+ from huggingface_hub import hf_hub_download
11
+ import subprocess
12
+ subprocess.run(
13
+ "pip install flash-attn==2.7.3 --no-build-isolation",
14
+ shell=True
15
+ )
16
+
17
+ from star.models.config import load_config_from_json, STARMultiModalConfig
18
+ from star.models.model import STARMultiModal
19
+
20
+
21
+ TEXTS = {
22
+ "zh": {
23
+ "title": "🌟 STAR 多模态演示",
24
+ "description": "基于STAR模型的多模态AI演示系统,支持文本生成图像、图像编辑和图像理解功能。",
25
+ "please_load_model": "请先加载模型!",
26
+ "please_upload_image": "请上传图像!",
27
+ "generation_failed": "生成失败!",
28
+ "generation_success_diffusion": "生成成功!",
29
+ "generation_success_vq": "生成成功!",
30
+ "edit_failed": "编辑失败!",
31
+ "edit_success_diffusion": "编辑成功!",
32
+ "edit_success_vq": "编辑成功!",
33
+ "understanding_failed": "理解失败!",
34
+ "generation_error": "生成过程中出错: ",
35
+ "edit_error": "编辑过程中出错: ",
36
+ "understanding_error": "理解过程中出错: ",
37
+ "tab_text_to_image": "🖼️ 文本生成图像",
38
+ "tab_image_edit": "🖌️ 图像编辑",
39
+ "tab_image_understanding": "📝 图像理解",
40
+ "text_prompt": "文本提示",
41
+ "text_prompt_placeholder": "A whimsical scene featuring a small elf with pointed ears and a green hat, sipping orange juice through a long straw from a disproportionately large orange. Next to the elf, a curious squirrel perches on its hind legs, while an owl with wide, observant eyes watches intently from a branch overhead. The orange's vibrant color contrasts with the muted browns and greens of the surrounding forest foliage.",
42
+ "advanced_params": "高级参数",
43
+ "cfg_scale": "CFG Scale",
44
+ "cfg_scale_info": "控制生成图像与文本的匹配程度",
45
+ "top_k": "Top-K",
46
+ "top_k_info": "采样时考虑的token数量",
47
+ "top_p": "Top-P",
48
+ "top_p_info": "核采样参数",
49
+ "generate_image": "🎨 生成图像",
50
+ "generated_image": "生成的图像",
51
+ "generation_status": "生成状态",
52
+ "input_image": "输入图像",
53
+ "edit_instruction": "编辑指令",
54
+ "edit_instruction_placeholder": "Remove the tiger in the water.",
55
+ "edit_image": "✏️ 编辑图像",
56
+ "edited_image": "编辑后的图像",
57
+ "edit_status": "编辑状态",
58
+ "question": "问题",
59
+ "question_placeholder": "Please describe the content of this image",
60
+ "max_generation_length": "最大生成长度",
61
+ "understand_image": "🔍 理解图像",
62
+ "understanding_result": "理解结果",
63
+ "usage_instructions": "使用说明",
64
+ "usage_step1": "1. **文本生成图像**: 输入文本描述,调整参数后点击生成",
65
+ "usage_step2": "2. **图像编辑**: 上传图像并输入编辑指令",
66
+ "usage_step3": "3. **图像理解**: 上传图像并提出问题",
67
+ "language": "语言 / Language"
68
+ },
69
+ "en": {
70
+ "title": "🌟 STAR Multi-Modal Demo",
71
+ "description": "A multi-modal AI demonstration system based on STAR model, supporting text-to-image generation, image editing, and image understanding.",
72
+ "please_load_model": "Please load the model first!",
73
+ "please_upload_image": "Please upload an image!",
74
+ "generation_failed": "Generation failed!",
75
+ "generation_success_diffusion": "Generation successful! ",
76
+ "generation_success_vq": "Generation successful! Using VQ decoder",
77
+ "edit_failed": "Editing failed!",
78
+ "edit_success_diffusion": "Editing successful! ",
79
+ "edit_success_vq": "Editing successful! Using VQ decoder",
80
+ "understanding_failed": "Understanding failed!",
81
+ "generation_error": "Error during generation: ",
82
+ "edit_error": "Error during editing: ",
83
+ "understanding_error": "Error during understanding: ",
84
+ "tab_text_to_image": "🖼️ Text to Image",
85
+ "tab_image_edit": "🖌️ Image Editing",
86
+ "tab_image_understanding": "📝 Image Understanding",
87
+ "text_prompt": "Text Prompt",
88
+ "text_prompt_placeholder": "A whimsical scene featuring a small elf with pointed ears and a green hat, sipping orange juice through a long straw from a disproportionately large orange. Next to the elf, a curious squirrel perches on its hind legs, while an owl with wide, observant eyes watches intently from a branch overhead. The orange's vibrant color contrasts with the muted browns and greens of the surrounding forest foliage.",
89
+ "advanced_params": "Advanced Parameters",
90
+ "cfg_scale": "CFG Scale",
91
+ "cfg_scale_info": "Controls how closely the generated image matches the text",
92
+ "top_k": "Top-K",
93
+ "top_k_info": "Number of tokens to consider during sampling",
94
+ "top_p": "Top-P",
95
+ "top_p_info": "Nucleus sampling parameter",
96
+ "generate_image": "🎨 Generate Image",
97
+ "generated_image": "Generated Image",
98
+ "generation_status": "Generation Status",
99
+ "input_image": "Input Image",
100
+ "edit_instruction": "Edit Instruction",
101
+ "edit_instruction_placeholder": "Remove the tiger in the water.",
102
+ "edit_image": "✏️ Edit Image",
103
+ "edited_image": "Edited Image",
104
+ "edit_status": "Edit Status",
105
+ "question": "Question",
106
+ "question_placeholder": "Please describe the content of this image",
107
+ "max_generation_length": "Max Generation Length",
108
+ "understand_image": "🔍 Understand Image",
109
+ "understanding_result": "Understanding Result",
110
+ "usage_instructions": "Usage Instructions",
111
+ "usage_step1": "1. **Text to Image**: Enter text description, adjust parameters and click generate",
112
+ "usage_step2": "2. **Image Editing**: Upload an image and enter editing instructions",
113
+ "usage_step3": "3. **Image Understanding**: Upload an image and ask questions",
114
+ "language": "语言 / Language"
115
+ }
116
+ }
117
+
118
+ class MockArgs:
119
+ def __init__(self):
120
+ self.data_type = "generation"
121
+ self.diffusion_as_decoder = True
122
+ self.ori_inp_dit = "seq"
123
+ self.grad_ckpt = False
124
+ self.diffusion_resolution = 1024
125
+ self.max_diff_seq_length = 256
126
+ self.max_seq_length = 8192
127
+ self.max_text_tokens = 512
128
+ self.max_pixels = 28 * 28 * 576
129
+ self.min_pixels = 28 * 28 * 16
130
+ self.vq_image_size = 384
131
+ self.vq_tokens = 576
132
+
133
+
134
+ def set_seed(seed=100):
135
+ if seed > 0:
136
+ random.seed(seed)
137
+ np.random.seed(seed)
138
+ torch.manual_seed(seed)
139
+ if torch.cuda.is_available():
140
+ torch.cuda.manual_seed(seed)
141
+ torch.cuda.manual_seed_all(seed)
142
+ torch.backends.cudnn.deterministic = True
143
+ torch.backends.cudnn.benchmark = False
144
+ return seed
145
+
146
+
147
+ def print_with_time(msg):
148
+ print(f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}: {msg}")
149
+
150
+
151
+ class STARInferencer:
152
+
153
+ def __init__(self, model_config_path, checkpoint_path, vq_checkpoint, device="cpu"):
154
+ self.device = device
155
+ self.model_config_path = model_config_path
156
+ self.checkpoint_path = checkpoint_path
157
+ self.vq_checkpoint_path = vq_checkpoint
158
+ self.model = None
159
+ self._load_model()
160
+
161
+ def _create_mock_args(self):
162
+
163
+ return MockArgs()
164
+
165
+ def _load_model(self):
166
+ try:
167
+ print_with_time("Loading model configuration...")
168
+ config_data = load_config_from_json(self.model_config_path)
169
+ model_config = STARMultiModalConfig(**config_data)
170
+
171
+ model_config.language_model.model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
172
+ model_config.pixel_encoder.model_path = self.vq_checkpoint_path
173
+ model_config.pixel_decoder.model_path = "Alpha-VLLM/Lumina-Image-2.0"
174
+
175
+ args = self._create_mock_args()
176
+
177
+ print_with_time("Initializing model...")
178
+ self.model = STARMultiModal(model_config, args)
179
+
180
+ if os.path.exists(self.checkpoint_path):
181
+ print_with_time(f"Loading checkpoint from {self.checkpoint_path}")
182
+ with torch.no_grad():
183
+ checkpoint = torch.load(self.checkpoint_path, map_location='cpu', weights_only=False)
184
+ if 'state_dict' in checkpoint:
185
+ state_dict = checkpoint['state_dict']
186
+ else:
187
+ state_dict = checkpoint
188
+
189
+ if not isinstance(state_dict, dict):
190
+ raise ValueError("Invalid checkpoint format")
191
+
192
+ print_with_time(f"Checkpoint contains {len(state_dict)} parameters")
193
+ self.model.load_state_dict(state_dict, strict=False)
194
+
195
+ print_with_time(f"Moving model to device: {self.device}")
196
+ self.model.to(self.device)
197
+
198
+ print_with_time("Setting model to eval mode...")
199
+ self.model.eval()
200
+
201
+ if torch.cuda.is_available():
202
+ print_with_time(f"GPU memory after model loading: {torch.cuda.memory_allocated()/1024**3:.2f}GB")
203
+
204
+ print_with_time("Model loaded successfully!")
205
+
206
+ except Exception as e:
207
+ print_with_time(f"Error loading model: {str(e)}")
208
+ import traceback
209
+ traceback.print_exc()
210
+ raise e
211
+
212
+ @spaces.GPU(duration=210)
213
+ def generate_image(self, prompt, num_images=1, cfg=20.0, topk=2000, topp=1.0, seed=0):
214
+
215
+ if self.model.device.type == 'cpu':
216
+ print_with_time("Moving model to GPU...")
217
+ self.model.to('cuda')
218
+ self.model.to(torch.bfloat16)
219
+ print_with_time("Model moved to GPU")
220
+
221
+ set_seed(seed)
222
+
223
+ print_with_time(f"Generating image for prompt: {prompt}")
224
+
225
+ cfg = max(1.0, min(20.0, float(cfg)))
226
+ topk = max(100, min(2000, int(topk)))
227
+ topp = max(0.1, min(1.0, float(topp)))
228
+
229
+ print_with_time(f"Using validated params: cfg={cfg}, topk={topk}, topp={topp}")
230
+
231
+ if not (torch.isfinite(torch.tensor(cfg)) and torch.isfinite(torch.tensor(topk)) and torch.isfinite(torch.tensor(topp))):
232
+ print_with_time("Warning: Non-finite parameters detected")
233
+ return None
234
+
235
+ try:
236
+ with torch.no_grad():
237
+ if torch.cuda.is_available():
238
+ torch.cuda.empty_cache()
239
+ print_with_time(f"GPU memory before generation: {torch.cuda.memory_allocated()/1024**3:.2f}GB")
240
+
241
+ if not isinstance(prompt, str) or len(prompt.strip()) == 0:
242
+ print_with_time("Warning: Invalid prompt")
243
+ return None
244
+
245
+ if not (0 < cfg <= 20 and 0 < topk <= 5000 and 0 < topp <= 1):
246
+ print_with_time(f"Warning: Invalid parameters - cfg={cfg}, topk={topk}, topp={topp}")
247
+ return None
248
+
249
+ print_with_time("Calling model.generate_images...")
250
+
251
+ safe_max_tokens = 576
252
+
253
+ output = self.model.generate_images(
254
+ prompt,
255
+ max_new_tokens=safe_max_tokens,
256
+ num_return_sequences=num_images,
257
+ cfg_weight=cfg,
258
+ topk_sample=topk,
259
+ topp_sample=topp,
260
+ reasoning=False,
261
+ return_dict=True
262
+ )
263
+ print_with_time("Model generation completed")
264
+
265
+ if output is None:
266
+ print_with_time("Warning: Model returned None output")
267
+ return None
268
+
269
+ print_with_time("Processing output images...")
270
+ result = self._process_output_images(output, num_images)
271
+ print_with_time("Image processing completed")
272
+ return result
273
+ except Exception as e:
274
+ print_with_time(f"Error during image generation: {str(e)}")
275
+ import traceback
276
+ traceback.print_exc()
277
+ if torch.cuda.is_available():
278
+ torch.cuda.empty_cache()
279
+ raise e
280
+
281
+ @spaces.GPU(duration=210)
282
+ def edit_image(self, image, instruction, num_images=1, cfg=20.0, topk=2000, topp=1.0, seed=0):
283
+
284
+ if self.model.device.type == 'cpu':
285
+ print_with_time("Moving model to GPU...")
286
+ self.model.to('cuda')
287
+ self.model.to(torch.bfloat16)
288
+ print_with_time("Model moved to GPU")
289
+
290
+ set_seed(seed)
291
+
292
+ if isinstance(image, np.ndarray):
293
+ image = Image.fromarray(image)
294
+
295
+ print_with_time(f"Editing image with instruction: {instruction}")
296
+
297
+ with torch.no_grad():
298
+ output = self.model.generate_images_edit(
299
+ [image],
300
+ instruction,
301
+ max_new_tokens=576,
302
+ num_return_sequences=num_images,
303
+ cfg_weight=cfg,
304
+ topk_sample=topk,
305
+ topp_sample=topp,
306
+ return_dict=True
307
+ )
308
+
309
+ if output is None:
310
+ return None
311
+
312
+ return self._process_output_images(output, num_images)
313
+
314
+ @spaces.GPU(duration=180)
315
+ def understand_image(self, image, question, max_new_tokens=256):
316
+
317
+ if self.model.device.type == 'cpu':
318
+ print_with_time("Moving model to GPU...")
319
+ self.model.to('cuda')
320
+ self.model.to(torch.bfloat16)
321
+ print_with_time("Model moved to GPU")
322
+
323
+ if isinstance(image, np.ndarray):
324
+ image = Image.fromarray(image)
325
+
326
+ print_with_time(f"Understanding image with question: {question}")
327
+
328
+ with torch.no_grad():
329
+ answer = self.model.inference_understand(
330
+ image=image,
331
+ question=question,
332
+ max_new_tokens=max_new_tokens
333
+ )
334
+
335
+ return answer
336
+
337
+ def _process_output_images(self, output, num_images):
338
+ image_size = 384
339
+
340
+ try:
341
+ if isinstance(output, dict):
342
+ output_images = output.get("output_images")
343
+ diff_images = output.get("diff_images")
344
+
345
+ results = {}
346
+
347
+ if output_images is not None:
348
+ if isinstance(output_images, torch.Tensor):
349
+ output_images = output_images.detach().cpu().numpy()
350
+
351
+ if output_images.size == 0:
352
+ print_with_time("Warning: Empty output_images array")
353
+ results["vq_images"] = None
354
+ else:
355
+ output_images = np.nan_to_num(output_images, nan=0.0, posinf=1.0, neginf=-1.0)
356
+ dec_vq = np.clip((output_images + 1) / 2 * 255, 0, 255)
357
+
358
+ if len(dec_vq.shape) == 3:
359
+ dec_vq = dec_vq.reshape(num_images, image_size, image_size, 3)
360
+
361
+ visual_img_vq = np.zeros((num_images, image_size, image_size, 3), dtype=np.uint8)
362
+ visual_img_vq[:, :, :] = dec_vq
363
+ imgs_vq = [Image.fromarray(visual_img_vq[j].astype(np.uint8)) for j in range(visual_img_vq.shape[0])]
364
+ results["vq_images"] = imgs_vq
365
+
366
+ if diff_images is not None:
367
+ results["diff_images"] = diff_images
368
+ else:
369
+ results["diff_images"] = None
370
+
371
+ return results
372
+ else:
373
+ if isinstance(output, torch.Tensor):
374
+ output = output.detach().cpu().numpy()
375
+
376
+ output = np.nan_to_num(output, nan=0.0, posinf=1.0, neginf=-1.0)
377
+ dec = np.clip((output + 1) / 2 * 255, 0, 255)
378
+
379
+ if len(dec.shape) == 3:
380
+ dec = dec.reshape(num_images, image_size, image_size, 3)
381
+
382
+ visual_img = np.zeros((num_images, image_size, image_size, 3), dtype=np.uint8)
383
+ visual_img[:, :, :] = dec
384
+ imgs = [Image.fromarray(visual_img[j].astype(np.uint8)) for j in range(visual_img.shape[0])]
385
+ return {"vq_images": imgs, "diff_images": None}
386
+
387
+ except Exception as e:
388
+ print_with_time(f"Error in _process_output_images: {str(e)}")
389
+ return {"vq_images": None, "diff_images": None}
390
+
391
+
392
+ inferencer = None
393
+
394
+
395
+
396
+ def save_language_setting(language):
397
+ try:
398
+ with open('.language_setting', 'w') as f:
399
+ f.write(language)
400
+ except:
401
+ pass
402
+
403
+ def update_interface_language(language):
404
+ global current_language
405
+ current_language = language
406
+
407
+ save_language_setting(language)
408
+
409
+ return [
410
+ language,
411
+ f"# {get_text('title')}",
412
+ get_text("description"),
413
+ get_text("text_prompt_placeholder"),
414
+ get_text("edit_instruction_placeholder"),
415
+ get_text("question_placeholder"),
416
+ f"""
417
+ ---
418
+ ### {get_text("usage_instructions")}
419
+ {get_text("usage_step1")}
420
+ {get_text("usage_step2")}
421
+ {get_text("usage_step3")}
422
+ """,
423
+ f"✅ Language switched to {language.upper()} successfully! / 语言已成功切换为{language.upper()}!" # 状态消息
424
+ ]
425
+
426
+ current_language = "en"
427
+
428
+ def get_text(key):
429
+ return TEXTS[current_language].get(key, key)
430
+
431
+
432
+ def auto_detect_device():
433
+ if torch.cuda.is_available():
434
+ device = f"cuda:{torch.cuda.current_device()}"
435
+ print_with_time(f"Detected CUDA device: {device}")
436
+ print_with_time(f"GPU name: {torch.cuda.get_device_name()}")
437
+ print_with_time(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
438
+ else:
439
+ device = "cpu"
440
+ print_with_time("No CUDA device detected, using CPU")
441
+ return device
442
+
443
+
444
+ def initialize_model_on_startup():
445
+ global inferencer
446
+
447
+ default_checkpoint = hf_hub_download(
448
+ repo_id="MM-MVR/STAR-7B",
449
+ filename="STAR-7B.pt"
450
+ )
451
+
452
+ default_config = "star/configs/STAR_Qwen2.5-VL-7B.json"
453
+
454
+ vq_checkpoint = hf_hub_download(
455
+ repo_id="MM-MVR/STAR-VQ",
456
+ filename="VQ-Model.pt"
457
+ )
458
+
459
+
460
+ if not os.path.exists(default_config):
461
+ print_with_time(f"⚠️ Model config file not found: {default_config}")
462
+ return False, f"Model config file not found: {default_config}"
463
+
464
+ if not os.path.exists(default_checkpoint):
465
+ print_with_time(f"⚠️ Model checkpoint file not found: {default_checkpoint}")
466
+ return False, f"Model checkpoint file not found: {default_checkpoint}"
467
+
468
+ try:
469
+ device = 'cpu'
470
+ print_with_time("Starting to load STAR model...")
471
+
472
+ inferencer = STARInferencer(default_config, default_checkpoint, vq_checkpoint, device)
473
+
474
+ print_with_time("✅ STAR model loaded successfully!")
475
+ return True, "✅ STAR model loaded successfully!"
476
+
477
+ except Exception as e:
478
+ error_msg = f"❌ Model loading failed: {str(e)}"
479
+ print_with_time(error_msg)
480
+ return False, error_msg
481
+
482
+
483
+
484
+
485
+ def text_to_image(prompt, cfg_scale=1.0, topk=1000, topp=0.8):
486
+ if inferencer is None:
487
+ return None, get_text("please_load_model")
488
+
489
+ cfg_scale = max(1.0, min(20.0, cfg_scale))
490
+ topk = max(100, min(2000, int(topk)))
491
+ topp = max(0.1, min(1.0, topp))
492
+ seed = 100
493
+
494
+ try:
495
+ print_with_time(f"Starting generation with params: cfg={cfg_scale}, topk={topk}, topp={topp}, seed={seed}")
496
+ result = inferencer.generate_image(prompt, cfg=cfg_scale, topk=topk, topp=topp, seed=seed)
497
+
498
+ if result is None:
499
+ return None, get_text("generation_failed")
500
+
501
+ if result.get("diff_images") and len(result["diff_images"]) > 0:
502
+ return result["diff_images"][0], get_text("generation_success_diffusion")
503
+ elif result.get("vq_images") and len(result["vq_images"]) > 0:
504
+ return result["vq_images"][0], get_text("generation_success_vq")
505
+ else:
506
+ return None, get_text("generation_failed")
507
+
508
+ except Exception as e:
509
+ return None, get_text("generation_error") + str(e)
510
+
511
+
512
+ def image_editing(image, instruction, cfg_scale=1.0, topk=1000, topp=0.8):
513
+ if inferencer is None:
514
+ return None, get_text("please_load_model")
515
+
516
+ if image is None:
517
+ return None, get_text("please_upload_image")
518
+
519
+
520
+ cfg_scale = max(1.0, min(20.0, cfg_scale))
521
+ topk = max(100, min(2000, int(topk)))
522
+ topp = max(0.1, min(1.0, topp))
523
+ seed = 100
524
+
525
+ try:
526
+ print_with_time(f"Starting image editing with params: cfg={cfg_scale}, topk={topk}, topp={topp}, seed={seed}")
527
+ result = inferencer.edit_image(image, instruction, cfg=cfg_scale, topk=topk, topp=topp, seed=seed)
528
+
529
+ if result is None:
530
+ return None, get_text("edit_failed")
531
+
532
+ if result.get("diff_images") and len(result["diff_images"]) > 0:
533
+ return result["diff_images"][0], get_text("edit_success_diffusion")
534
+ elif result.get("vq_images") and len(result["vq_images"]) > 0:
535
+ return result["vq_images"][0], get_text("edit_success_vq")
536
+ else:
537
+ return None, get_text("edit_failed")
538
+
539
+ except Exception as e:
540
+ return None, get_text("edit_error") + str(e)
541
+
542
+
543
+ def image_understanding(image, question, max_new_tokens=256):
544
+ if inferencer is None:
545
+ return get_text("please_load_model")
546
+
547
+ if image is None:
548
+ return get_text("please_upload_image")
549
+
550
+ try:
551
+ answer = inferencer.understand_image(image, question, max_new_tokens)
552
+ return answer if answer else get_text("understanding_failed")
553
+
554
+ except Exception as e:
555
+ return get_text("understanding_error") + str(e)
556
+
557
+
558
+ def change_language(language):
559
+ global current_language
560
+ current_language = language
561
+
562
+ return (
563
+ get_text("title"),
564
+ get_text("description"),
565
+ get_text("tab_text_to_image"),
566
+ get_text("text_prompt"),
567
+ get_text("text_prompt_placeholder"),
568
+ get_text("advanced_params"),
569
+ get_text("cfg_scale"),
570
+ get_text("cfg_scale_info"),
571
+ get_text("top_k"),
572
+ get_text("top_k_info"),
573
+ get_text("top_p"),
574
+ get_text("top_p_info"),
575
+ get_text("random_seed"),
576
+ get_text("random_seed_info"),
577
+ get_text("generate_image"),
578
+ get_text("generated_image"),
579
+ get_text("generation_status"),
580
+ get_text("tab_image_edit"),
581
+ get_text("input_image"),
582
+ get_text("edit_instruction"),
583
+ get_text("edit_instruction_placeholder"),
584
+ get_text("edit_image"),
585
+ get_text("edited_image"),
586
+ get_text("edit_status"),
587
+ get_text("tab_image_understanding"),
588
+ get_text("question"),
589
+ get_text("question_placeholder"),
590
+ get_text("max_generation_length"),
591
+ get_text("understand_image"),
592
+ get_text("understanding_result"),
593
+ get_text("usage_instructions"),
594
+ get_text("usage_step1"),
595
+ get_text("usage_step2"),
596
+ get_text("usage_step3")
597
+ )
598
+
599
+
600
+ def load_example_image(image_path):
601
+ try:
602
+ if os.path.exists(image_path):
603
+ return Image.open(image_path)
604
+ except Exception as e:
605
+ print(f"Error loading example image: {e}")
606
+ return None
607
+
608
+
609
+
610
+ def create_interface():
611
+
612
+ print_with_time("Initializing STAR demo system...")
613
+ model_loaded, status_message = initialize_model_on_startup()
614
+
615
+ with gr.Blocks(title="🌟 STAR Multi-Modal Demo", theme=gr.themes.Soft()) as demo:
616
+
617
+ language_state = gr.State(value=current_language)
618
+ title_md = gr.Markdown(f"# {get_text('title')}")
619
+ desc_md = gr.Markdown(get_text("description"))
620
+
621
+ with gr.Row():
622
+ with gr.Column():
623
+ language_dropdown = gr.Dropdown(
624
+ choices=[("English", "en"), ("中文", "zh")],
625
+ value=current_language,
626
+ label="Language / 语言",
627
+ interactive=True
628
+ )
629
+
630
+ with gr.Tabs():
631
+ with gr.Tab(get_text("tab_text_to_image")) as txt_tab:
632
+ with gr.Row():
633
+ with gr.Column():
634
+ txt_prompt = gr.Textbox(
635
+ label=get_text("text_prompt"),
636
+ value=get_text("text_prompt_placeholder"),
637
+ lines=3
638
+ )
639
+
640
+ with gr.Accordion(get_text("advanced_params"), open=False):
641
+ txt_cfg_scale = gr.Slider(
642
+ minimum=1.0, maximum=20.0, value=1.1, step=0.1,
643
+ label=get_text("cfg_scale"), info=get_text("cfg_scale_info")
644
+ )
645
+ txt_topk = gr.Slider(
646
+ minimum=100, maximum=2000, value=1000, step=50,
647
+ label=get_text("top_k"), info=get_text("top_k_info")
648
+ )
649
+ txt_topp = gr.Slider(
650
+ minimum=0.1, maximum=1.0, value=0.8, step=0.05,
651
+ label=get_text("top_p"), info=get_text("top_p_info")
652
+ )
653
+
654
+ txt_generate_btn = gr.Button(get_text("generate_image"), variant="primary")
655
+
656
+ with gr.Column():
657
+ txt_output_image = gr.Image(label=get_text("generated_image"))
658
+ txt_status = gr.Textbox(label=get_text("generation_status"), interactive=False)
659
+
660
+
661
+ with gr.Tab(get_text("tab_image_edit")) as edit_tab:
662
+ with gr.Row():
663
+ with gr.Column():
664
+ edit_input_image = gr.Image(
665
+ label=get_text("input_image"),
666
+ value=load_example_image('assets/editing.png')
667
+ )
668
+ edit_instruction = gr.Textbox(
669
+ label=get_text("edit_instruction"),
670
+ value=get_text("edit_instruction_placeholder"),
671
+ lines=2
672
+ )
673
+
674
+ with gr.Accordion(get_text("advanced_params"), open=False):
675
+ edit_cfg_scale = gr.Slider(
676
+ minimum=1.0, maximum=20.0, value=1.1, step=0.1,
677
+ label=get_text("cfg_scale")
678
+ )
679
+ edit_topk = gr.Slider(
680
+ minimum=100, maximum=2000, value=1000, step=50,
681
+ label=get_text("top_k")
682
+ )
683
+ edit_topp = gr.Slider(
684
+ minimum=0.1, maximum=1.0, value=0.8, step=0.05,
685
+ label=get_text("top_p")
686
+ )
687
+
688
+ edit_btn = gr.Button(get_text("edit_image"), variant="primary")
689
+
690
+ with gr.Column():
691
+ edit_output_image = gr.Image(label=get_text("edited_image"))
692
+ edit_status = gr.Textbox(label=get_text("edit_status"), interactive=False)
693
+
694
+
695
+ with gr.Tab(get_text("tab_image_understanding")) as understand_tab:
696
+ with gr.Row():
697
+ with gr.Column():
698
+ understand_input_image = gr.Image(
699
+ label=get_text("input_image"),
700
+ value=load_example_image('assets/understand.png')
701
+ )
702
+ understand_question = gr.Textbox(
703
+ label=get_text("question"),
704
+ value=get_text("question_placeholder"),
705
+ lines=2
706
+ )
707
+
708
+ with gr.Accordion(get_text("advanced_params"), open=False):
709
+ understand_max_tokens = gr.Slider(
710
+ minimum=64, maximum=1024, value=256, step=64,
711
+ label=get_text("max_generation_length")
712
+ )
713
+
714
+ understand_btn = gr.Button(get_text("understand_image"), variant="primary")
715
+
716
+ with gr.Column():
717
+ understand_output = gr.Textbox(
718
+ label=get_text("understanding_result"),
719
+ lines=15,
720
+ interactive=False
721
+ )
722
+
723
+ usage_md = gr.Markdown(
724
+ f"""
725
+ ---
726
+ ### {get_text("usage_instructions")}
727
+ {get_text("usage_step1")}
728
+ {get_text("usage_step2")}
729
+ {get_text("usage_step3")}
730
+ """
731
+ )
732
+
733
+ txt_generate_btn.click(
734
+ fn=text_to_image,
735
+ inputs=[txt_prompt, txt_cfg_scale, txt_topk, txt_topp],
736
+ outputs=[txt_output_image, txt_status]
737
+ )
738
+
739
+ edit_btn.click(
740
+ fn=image_editing,
741
+ inputs=[edit_input_image, edit_instruction, edit_cfg_scale, edit_topk, edit_topp],
742
+ outputs=[edit_output_image, edit_status]
743
+ )
744
+
745
+ understand_btn.click(
746
+ fn=image_understanding,
747
+ inputs=[understand_input_image, understand_question, understand_max_tokens],
748
+ outputs=understand_output
749
+ )
750
+
751
+
752
+ language_dropdown.change(
753
+ fn=update_interface_language,
754
+ inputs=[language_dropdown],
755
+ outputs=[language_state, title_md, desc_md, txt_prompt, edit_instruction, understand_question, usage_md, txt_status]
756
+ )
757
+
758
+ return demo
759
+
760
+ demo = create_interface()
761
+
762
+ demo.launch(share=True, show_error=True)
763
+
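For quick checks outside the Gradio UI, the sketch below drives STARInferencer directly, mirroring initialize_model_on_startup() and text_to_image() above. It assumes the classes defined in app.py are in scope on a machine with a GPU; the prompt and output filename are illustrative.

from huggingface_hub import hf_hub_download

ckpt = hf_hub_download(repo_id="MM-MVR/STAR-7B", filename="STAR-7B.pt")
vq_ckpt = hf_hub_download(repo_id="MM-MVR/STAR-VQ", filename="VQ-Model.pt")
inferencer = STARInferencer("star/configs/STAR_Qwen2.5-VL-7B.json", ckpt, vq_ckpt, device="cpu")

# generate_image moves the model to the GPU on first call and returns a dict with
# "diff_images" (Lumina diffusion decoder) and "vq_images" (VQ decoder) image lists.
result = inferencer.generate_image("A watercolor fox in a snowy forest",
                                   cfg=1.1, topk=1000, topp=0.8, seed=100)
if result:
    images = result.get("diff_images") or result.get("vq_images")
    if images:
        images[0].save("generated.png")  # illustrative output path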
assets/editing.png ADDED

Git LFS Details

  • SHA256: 725278dda08a4ce97589396aac69bb0c703b05d9de861fb9b278444f5b936af5
  • Pointer size: 131 Bytes
  • Size of remote file: 591 kB
assets/understand.png ADDED

Git LFS Details

  • SHA256: cb0fe61f3b81bc2ffbc0f5838d745881f0e77b95c7e3fd96bf43e7715dfd7fd8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.73 MB
requirements.txt ADDED
@@ -0,0 +1,30 @@
+
+ torch>=2.2.2
+ torchvision>=0.17.2
+ torchaudio>=2.2.2
+ transformers==4.51.0
+ diffusers==0.33.0
+ decord>=0.6.0
+ attrdict2
+ accelerate>=0.32.0
+ timm>=1.0.15
+ opencv-python>=4.10.0
+ pillow>=10.4.0
+ einops>=0.8.0
+ xformers>=0.0.28
+ numpy>=1.26.0
+ pandas>=2.2.0
+ datasets>=3.0.0
+ tokenizers>=0.21.0
+ sentencepiece>=0.1.99
+ torchmetrics>=1.4.0
+ tqdm>=4.66.0
+ pyyaml>=6.0.0
+ requests>=2.32.0
+ packaging>=24.1
+ ipython>=8.26.0
+ matplotlib>=3.9.0
+ deepspeed>=0.14.4
+ wandb>=0.16.3
+ gradio>=5.34.0
+ qwen-vl-utils
star/.DS_Store ADDED
Binary file (6.15 kB)
 
star/configs/STAR_Qwen2.5-VL-3B.json ADDED
@@ -0,0 +1,35 @@
+ {
+     "model_name": "STAR_Qwen2.5-3B_VQGAN",
+     "model_type": "STARMultiModalityConfig",
+     "language_model": {
+         "model_name": "Qwen2.5-VL",
+         "model_path": "checkpoints/Qwen2.5-VL-3B-Instruct"
+     },
+     "pixel_encoder": {
+         "model_name": "VQ_Model",
+         "model_path": "checkpoints/VQ-Model.pt",
+         "image_token_size": 65536,
+         "n_embed": 512,
+         "num_tokens": 576,
+         "num_heads": 8
+     },
+     "pixel_adapter": {
+         "model_name": "MLP_GELU",
+         "depth": 2,
+         "input_dim": 512,
+         "n_embed": 2048
+     },
+     "stacked_ar": {
+         "num_layers": 16
+     },
+     "pixel_output_head": {
+         "image_token_embed": 4096,
+         "image_token_size": 65536,
+         "n_embed": 2048
+     },
+     "pixel_decoder": {
+         "model_name": "LUMINA2",
+         "model_path": "checkpoints/lumina-image2"
+     },
+     "torch_dtype": "bfloat16"
+ }
star/configs/STAR_Qwen2.5-VL-7B.json ADDED
@@ -0,0 +1,35 @@
+ {
+     "model_name": "STAR_Qwen2.5-7B_VQGAN",
+     "model_type": "STARMultiModalityConfig",
+     "language_model": {
+         "model_name": "Qwen2.5-VL",
+         "model_path": "checkpoints/Qwen2.5-VL-7B-Instruct"
+     },
+     "pixel_encoder": {
+         "model_name": "VQ_Model",
+         "model_path": "checkpoints/VQ-Model.pt",
+         "image_token_size": 65536,
+         "n_embed": 512,
+         "num_tokens": 576,
+         "num_heads": 8
+     },
+     "pixel_adapter": {
+         "model_name": "MLP_GELU",
+         "depth": 4,
+         "input_dim": 512,
+         "n_embed": 3584
+     },
+     "stacked_ar": {
+         "num_layers": 14
+     },
+     "pixel_output_head": {
+         "image_token_embed": 4096,
+         "image_token_size": 65536,
+         "n_embed": 3584
+     },
+     "pixel_decoder": {
+         "model_name": "LUMINA2",
+         "model_path": "checkpoints/lumina-image2"
+     },
+     "torch_dtype": "bfloat16"
+ }
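The checkpoints/... paths above are local placeholders; at runtime app.py loads this JSON and overrides them with hosted weights. A minimal sketch of that flow, following STARInferencer._load_model (the local VQ path stands in for whatever hf_hub_download returns):

from star.models.config import load_config_from_json, STARMultiModalConfig

config_data = load_config_from_json("star/configs/STAR_Qwen2.5-VL-7B.json")
model_config = STARMultiModalConfig(**config_data)

# Same overrides as in app.py
model_config.language_model.model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
model_config.pixel_encoder.model_path = "VQ-Model.pt"  # local file downloaded from MM-MVR/STAR-VQ
model_config.pixel_decoder.model_path = "Alpha-VLLM/Lumina-Image-2.0"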
star/models/adapter/projector.py ADDED
@@ -0,0 +1,26 @@
+
+ import torch
+ import torch.nn as nn
+
+ class MlpProjector(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+
+         self.cfg = cfg
+
+         if cfg.model_name == "MLP_GELU":
+             mlp_depth = cfg.get("depth", 1)
+             modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
+             for _ in range(1, mlp_depth):
+                 modules.append(nn.GELU())
+                 modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
+             modules = nn.Sequential(*modules)
+
+         else:
+             raise ValueError(f"Unknown projector type: {cfg.model_name}")
+
+         self.layers = modules
+
+     def forward(self, x):
+
+         return self.layers(x)
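A shape sketch for this adapter: with the 7B config above (depth 4, input_dim 512, n_embed 3584, 576 VQ tokens per image) it lifts VQ codebook entries to the LLM hidden size. The random tensor is illustrative.

import torch
from attrdict2 import AttrDict
from star.models.adapter.projector import MlpProjector

cfg = AttrDict({"model_name": "MLP_GELU", "depth": 4, "input_dim": 512, "n_embed": 3584})
projector = MlpProjector(cfg)
codebook_entries = torch.randn(1, 576, 512)   # one image worth of VQ embeddings
print(projector(codebook_entries).shape)      # torch.Size([1, 576, 3584])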
star/models/config.py ADDED
@@ -0,0 +1,23 @@
+ import json
+ from attrdict2 import AttrDict
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ def load_config_from_json(json_path):
+     with open(json_path, "r") as f:
+         config_data = json.load(f)
+     return config_data
+
+ class STARMultiModalConfig(PretrainedConfig):
+     model_type = "STARMultiModal"
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         self.pixel_encoder = AttrDict(kwargs.get("pixel_encoder", {}))
+         self.pixel_adapter = AttrDict(kwargs.get("pixel_adapter", {}))
+         self.pixel_output_head = AttrDict(kwargs.get("pixel_output_head", {}))
+         self.language_model = AttrDict(kwargs.get("language_model", {}))
+         self.stacked_ar = AttrDict(kwargs.get("stacked_ar", {}))
+         self.pixel_decoder = AttrDict(kwargs.get("pixel_decoder", {}))
+
star/models/data_process_utils.py ADDED
@@ -0,0 +1,65 @@
+ import numpy as np
+ from PIL import Image
+ import torch
+ import torchvision
+ from torchvision import transforms
+
+ BACKGROUND_COLOR = (127, 127, 127)
+
+ from torchvision.transforms import InterpolationMode
+
+ def preprocess_image_with_min_size(image, min_factor=28):
+     width, height = image.size
+     if height < min_factor or width < min_factor:
+         scale_factor = max(min_factor / height, min_factor / width)
+         new_width = int(width * scale_factor)
+         new_height = int(height * scale_factor)
+
+         image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+     return image
+
+ def preprocess_image_gen(images, processor, vq_transform):
+
+     image_list = []
+     grid_thw_list = []
+     vq_image_list = []
+     for image in images:
+         image = preprocess_image_with_min_size(image)
+
+         visual_processed = processor.preprocess(image, return_tensors="pt")
+         image_tensor = visual_processed["pixel_values"]
+         if isinstance(image_tensor, list):
+             image_tensor = image_tensor[0]
+         image_list.append(image_tensor)
+
+         grid_thw = visual_processed["image_grid_thw"][0]
+         grid_thw_list.append(grid_thw)
+
+         vq_image = vq_transform(image)
+         vq_image_list.append(vq_image)
+
+     image_tensor = torch.stack(image_list, dim=0)
+     grid_thw = torch.stack(grid_thw_list, dim=0)
+     vq_image = torch.stack(vq_image_list, dim=0)
+
+     return {
+         "pixel_values": image_tensor,
+         "image_grid_thw": grid_thw,
+         "vq_pixel_values": vq_image
+     }
+
+
+
+ def get_vq_transform(args):
+     return transforms.Compose([
+         transforms.Resize((args.vq_image_size, args.vq_image_size), interpolation=InterpolationMode.BILINEAR),
+         transforms.ToTensor(),  # [0, 255] -> [0, 1]
+         transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
+     ])
+
+ def get_full_transform(args):
+     return transforms.Compose([
+         transforms.Resize((1024, 1024), interpolation=InterpolationMode.BILINEAR),
+         transforms.ToTensor(),  # [0, 255] -> [0, 1]
+         transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
+     ])
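A small usage sketch for the VQ transform: it resizes to vq_image_size × vq_image_size (384 in MockArgs from app.py) and maps pixel values from [0, 255] to [-1, 1]. The SimpleNamespace stands in for the real args object, and the image path points at the asset added in this commit.

from types import SimpleNamespace
from PIL import Image
from star.models.data_process_utils import get_vq_transform

args = SimpleNamespace(vq_image_size=384)
vq_transform = get_vq_transform(args)
tensor = vq_transform(Image.open("assets/editing.png").convert("RGB"))
print(tensor.shape)                              # torch.Size([3, 384, 384])
print(tensor.min().item(), tensor.max().item())  # roughly within [-1.0, 1.0]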
star/models/model.py ADDED
@@ -0,0 +1,587 @@
1
+ import os
2
+ import math
3
+ import torch
4
+ import requests
5
+ from io import BytesIO
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms as T
10
+ from torchvision.transforms.functional import InterpolationMode
11
+ from torch.nn import CrossEntropyLoss
12
+ from transformers import (
13
+ AutoConfig,
14
+ AutoTokenizer,
15
+ AutoModelForCausalLM,
16
+ PreTrainedModel
17
+ )
18
+
19
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, Qwen2VLProcessor
20
+
21
+ from star.models.config import STARMultiModalConfig
22
+ from star.models.pixel_encoder.vq_model import VQ_Model
23
+ from star.models.adapter.projector import MlpProjector
24
+ from star.models.pixel_decoder.lumina2_decoder import Lumina2Decoder
25
+ from star.models.data_process_utils import get_full_transform, get_vq_transform, preprocess_image_gen
26
+ from star.models.rope_2d import get_rope_index_25
27
+
28
+ class STARMultiModal(PreTrainedModel):
29
+ def __init__(self, config: STARMultiModalConfig, args=None, **kwargs):
30
+ super().__init__(config)
31
+
32
+ self.config = config
33
+ self.args = args if args is not None else kwargs.get("args", None)
34
+
35
+ # Pixel Encoder Generation
36
+ model_name = config.pixel_encoder.model_name
37
+ if model_name == "VQ_Model":
38
+ self.pixel_encoder = VQ_Model(config.pixel_encoder)
39
+ else:
40
+ assert None, f"Unsupported {model_name}"
41
+ self.pixel_encoder.eval()
42
+
43
+
44
+ # Pixel Adapter Generation
45
+ model_name = config.pixel_adapter.model_name
46
+ if model_name == "MLP_GELU":
47
+ self.pixel_adapter = MlpProjector(config.pixel_adapter)
48
+ else:
49
+ assert None, f"Unsupported {model_name}"
50
+
51
+ # Pixel Output Head Generation
52
+ self.pixel_output_head = torch.nn.Linear(config.pixel_output_head.n_embed, config.pixel_output_head.image_token_size)
53
+
54
+ if getattr(args, "diffusion_as_decoder") and args.diffusion_as_decoder:
55
+ self.diffusion_decoder = Lumina2Decoder(config.pixel_decoder, args)
56
+ else:
57
+ self.diffusion_decoder = None
58
+
59
+ # Large Language Model
60
+ model_name, model_path = config.language_model.model_name, config.language_model.model_path
61
+
62
+ if model_name == "Qwen2.5-VL":
63
+ self.llm = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="cuda")
64
+ self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
65
+ self.tokenizer = self.processor.tokenizer
66
+
67
+ self.image_processor = self.processor.image_processor
68
+ self.image_processor.max_pixels = self.args.max_pixels
69
+ self.image_processor.min_pixels = self.args.min_pixels
70
+ self.image_processor.size["longest_edge"] = self.args.max_pixels
71
+ self.image_processor.size["shortest_edge"] = self.args.min_pixels
72
+
73
+ special_token_tags = ["<|vision_start|>", "<|vision_pad|>", "<|image_pad|>", "<|vision_end|>", "<|fim_pad|>"]
74
+ self.special_tokens = {tag: self.tokenizer.vocab.get(tag, None) for tag in special_token_tags}
75
+
76
+ else:
77
+ assert None, f"unsupported {model_name}: {model_path}"
78
+ self.llm.generation_config.pad_token_id = self.tokenizer.encode(self.tokenizer.pad_token)[0]
79
+
80
+ if self.args.grad_ckpt:
81
+ self.llm.gradient_checkpointing_enable()
82
+ self.llm.visual.gradient_checkpointing_enable()
83
+
84
+
85
+ stacked_ar_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
86
+ num_layers_to_extract = config.stacked_ar.num_layers
87
+ stacked_ar_config.num_hidden_layers = num_layers_to_extract
88
+
89
+ self.stacked_ar = Qwen2_5_VLForConditionalGeneration(stacked_ar_config)
90
+
91
+ temp_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
92
+ total_layers = len(temp_model.model.layers)
93
+ start_layer = max(0, total_layers - num_layers_to_extract)
94
+ temp_model.model.layers = temp_model.model.layers[start_layer:]
95
+ self.stacked_ar.load_state_dict(temp_model.state_dict(), strict=False)
96
+
97
+ self.stacked_ar = self.stacked_ar.to("cuda")
98
+ del self.stacked_ar.visual, self.stacked_ar.model.embed_tokens, self.stacked_ar.lm_head
99
+
100
+
101
+ # For Inference Generation
102
+ def generate_images(self, prompt, max_new_tokens=256, num_return_sequences=1, cfg_weight=5.0, topk_sample=1000, topp_sample=1.0, temperature=1.0, reasoning=False, return_dict=False):
103
+
104
+ if reasoning:
105
+ return self.generate_images_reasoning(prompt, max_new_tokens, num_return_sequences, cfg_weight, topk_sample, topp_sample, temperature, return_dict)
106
+
107
+ messages = [{'role': 'user', 'content': prompt}]
108
+ text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
109
+ text_token = self.tokenizer.encode(text)
110
+ text_token = torch.tensor(text_token).long().to(self.device)
111
+
112
+ keys = list(self.special_tokens.keys())
113
+ start_token = (torch.ones(1) * self.special_tokens.get(keys[0])).long().to(self.device)
114
+
115
+ input_ids = torch.cat((text_token, start_token)).long().to(self.device)
116
+ tokens = torch.zeros((num_return_sequences*2, len(input_ids)), dtype=torch.int).cuda()
117
+ assistant_tokens = input_ids[-4:]
118
+
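+ # Even-indexed rows keep the conditional prompt; odd rows are masked to <|fim_pad|>
+ # (chat-template head and tail preserved) to form the unconditional branch used by
+ # the classifier-free guidance mix of logit_cond and logit_uncond below.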
119
+ for i in range(num_return_sequences*2):
120
+ tokens[i, :] = input_ids
121
+ if i % 2 != 0:
122
+ tokens[i, 1:-1] = self.special_tokens.get(keys[4])
123
+ tokens[i, -4:] = assistant_tokens
124
+
125
+ inputs_embeds = self.llm.model.embed_tokens(tokens).to(self.device)
126
+ generated_tokens = torch.zeros((num_return_sequences, max_new_tokens), dtype=torch.int).cuda()
127
+
128
+ for i in range(max_new_tokens):
129
+ outputs = self.llm.model(
130
+ inputs_embeds=inputs_embeds,
131
+ use_cache=True,
132
+ past_key_values=outputs.past_key_values if i != 0 else None,
133
+ output_hidden_states=True)
134
+ last_hidden_states = outputs[0]
135
+
136
+ output_states = self.stacked_ar.model(
137
+ inputs_embeds=last_hidden_states,
138
+ past_key_values=output_states.past_key_values if i != 0 else None,
139
+ output_hidden_states=True,
140
+ use_cache=True)
141
+
142
+ last_hidden_states = output_states.hidden_states[-1]
143
+
144
+ logits = self.pixel_output_head(last_hidden_states[:, -1, :])
145
+ logit_cond = logits[0::2, :]
146
+ logit_uncond = logits[1::2, :]
147
+ logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
148
+ next_token, _ = self.sample(logits, temperature=1.0, top_k=topk_sample, top_p=topp_sample)
149
+ generated_tokens[:, i] = next_token.squeeze(dim=-1)
150
+ next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
151
+
152
+ vqgan_embeds = self.pixel_encoder.get_codebook_entry(next_token)
153
+ img_embeds = self.pixel_adapter(vqgan_embeds)
154
+ inputs_embeds = img_embeds.unsqueeze(dim=1)
155
+
156
+ latent_size = int(math.sqrt(max_new_tokens))
157
+ output_images = self.pixel_encoder.decode_code(generated_tokens.to(dtype=torch.int), shape=[num_return_sequences, self.pixel_encoder.config.codebook_embed_dim, latent_size, latent_size])
158
+ output_images = output_images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
159
+
160
+ diff_images = None
161
+ if self.diffusion_decoder is not None:
162
+ gen_image_embeds = self.pixel_encoder.get_codebook_entry(generated_tokens)
163
+
164
+ if self.args.diffusion_resolution==512:
165
+ self.diffusion_decoder.pipe.transformer.config.sample_size=16
166
+ elif self.args.diffusion_resolution==1024:
167
+ self.diffusion_decoder.pipe.transformer.config.sample_size=32
168
+ diff_images = self.diffusion_decoder.pipe(
169
+ prompt,
170
+ num_inference_steps=40,
171
+ guidance_scale=4.5,
172
+ gen_image_embeds=gen_image_embeds, #gen_image_embeds,
173
+ control_emd="text",
174
+ ori_inp_way=self.diffusion_decoder.transformer.ori_inp_dit,
175
+ only_t2i="vqconcat",
176
+ img_guidance_scale=1.05,
177
+ height=self.args.diffusion_resolution,
178
+ width=self.args.diffusion_resolution
179
+ ).images
180
+ if return_dict:
181
+ return {"output_images": output_images, "generated_tokens": generated_tokens, "diff_images": diff_images}
182
+ return output_images
183
+
184
+ def answer_text_qwen_vl(self, question, max_new_tokens=256, do_sample=True):
185
+
186
+ messages = [
187
+ {
188
+ "role": "user",
189
+ "content": [
190
+ {"type": "text", "text": question},
191
+ ],
192
+ }
193
+ ]
194
+
195
+ # Preparation for inference
196
+ text = self.processor.apply_chat_template(
197
+ messages, tokenize=False, add_generation_prompt=True
198
+ )
199
+ # image_inputs, video_inputs = process_vision_info(messages)
200
+ inputs = self.processor(
201
+ text=[text],
202
+ images=None,
203
+ videos=None,
204
+ padding=True,
205
+ return_tensors="pt",
206
+ )
207
+ inputs = inputs.to(self.llm.device)
208
+
209
+ # Inference: Generation of the output
210
+ generated_ids = self.llm.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=do_sample)
211
+ generated_ids_trimmed = [
212
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
213
+ ]
214
+ output_text = self.processor.batch_decode(
215
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
216
+ )
217
+
218
+ return output_text[0] if output_text else ""
219
+
220
+ def generate_images_reasoning(self, prompt, max_new_tokens=256, num_return_sequences=1, cfg_weight=5.0, topk_sample=1000, topp_sample=1.0, temperature=1.0, return_dict=False):
221
+
222
+ messages = [{'role': 'user', 'content': prompt}]
223
+ text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
224
+ text_token = self.tokenizer.encode(text)
225
+ text_token = torch.tensor(text_token).long().to(self.device)
226
+
227
+ keys = list(self.special_tokens.keys())
228
+ start_token = (torch.ones(1) * self.special_tokens.get(keys[0])).long().to(self.device)
229
+
230
+ input_ids = torch.cat((text_token, start_token)).long().to(self.device)
231
+ tokens = torch.zeros((num_return_sequences*2, len(input_ids)), dtype=torch.int).cuda()
232
+ assistant_tokens = input_ids[-4:]
233
+
234
+ for i in range(num_return_sequences*2):
235
+ tokens[i, :] = input_ids
236
+ if i % 2 != 0:
237
+ tokens[i, 1:-1] = self.special_tokens.get(keys[4])
238
+ tokens[i, -4:] = assistant_tokens
239
+
240
+ generated_tokens = torch.zeros((num_return_sequences, max_new_tokens), dtype=torch.int).cuda()
241
+ answer_tokens_list = self.answer_text_qwen_vl(prompt, do_sample=False)
242
+
243
+ if answer_tokens_list:
244
+ answer_tokens_list = self.tokenizer.encode(answer_tokens_list, add_special_tokens=False)
245
+ answer_tokens = torch.tensor([answer_tokens_list], device=self.device) # [1, seq_len]
246
+ magic_prompt = " Ultra HD, 4K, cinematic composition"
247
+
248
+
249
+ magic_prompt_tokens = self.tokenizer.encode(magic_prompt, add_special_tokens=False)
250
+ magic_prompt_tensor = torch.tensor([magic_prompt_tokens], device=self.device) # [1, magic_seq_len]
251
+
252
+ answer_tokens = torch.cat([answer_tokens, magic_prompt_tensor], dim=1) # [1, seq_len + magic_seq_len]
253
+ answer_prompt = self.tokenizer.decode(answer_tokens[0]).split("assistant\n")[-1]
254
+
255
+ special_token = self.special_tokens.get(keys[4])
256
+ special_token_tensor = torch.tensor([[special_token]], device=self.device)
257
+ special_token_expanded = special_token_tensor.expand(-1, answer_tokens.size(1))
258
+
259
+ answer_tokens_with_special = torch.cat([answer_tokens, special_token_expanded], dim=0)
260
+
261
+ batch_size = tokens.size(0) # num_return_sequences*2
262
+ answer_tokens_expanded = answer_tokens_with_special.repeat(batch_size // 2, 1)
263
+
264
+ input_tokens = torch.cat((tokens[:, :14], answer_tokens_expanded, tokens[:, -6:]), dim=1)
265
+
266
+ else:
267
+ input_tokens = tokens
268
+ answer_prompt = None
269
+
270
+ inputs_embeds = self.llm.model.embed_tokens(input_tokens).to(self.device)
271
+
272
+ for i in range(max_new_tokens):
273
+ outputs = self.llm.model(
274
+ inputs_embeds=inputs_embeds,
275
+ use_cache=True,
276
+ past_key_values=outputs.past_key_values if i != 0 else None,
277
+ output_hidden_states=True)
278
+ last_hidden_states = outputs[0]
279
+
280
+ output_states = self.stacked_ar.model(
281
+ inputs_embeds=last_hidden_states,
282
+ past_key_values=output_states.past_key_values if i != 0 else None,
283
+ output_hidden_states=True,
284
+ use_cache=True)
285
+
286
+ last_hidden_states = output_states.hidden_states[-1]
287
+
288
+ logits = self.pixel_output_head(last_hidden_states[:, -1, :])
289
+ logit_cond = logits[0::2, :]
290
+ logit_uncond = logits[1::2, :]
291
+ logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
292
+ next_token, _ = self.sample(logits, temperature=1.0, top_k=topk_sample, top_p=topp_sample)
293
+ generated_tokens[:, i] = next_token.squeeze(dim=-1)
294
+ next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
295
+
296
+ vqgan_embeds = self.pixel_encoder.get_codebook_entry(next_token)
297
+ img_embeds = self.pixel_adapter(vqgan_embeds)
298
+ inputs_embeds = img_embeds.unsqueeze(dim=1)
299
+
300
+ latent_size = int(math.sqrt(max_new_tokens))
301
+ output_images = self.pixel_encoder.decode_code(generated_tokens.to(dtype=torch.int), shape=[num_return_sequences, self.pixel_encoder.config.codebook_embed_dim, latent_size, latent_size])
302
+ output_images = output_images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
303
+
304
+ diff_images = None
305
+ if self.diffusion_decoder is not None:
306
+ gen_image_embeds = self.pixel_encoder.get_codebook_entry(generated_tokens)
307
+ diff_prompt = answer_prompt if answer_prompt else prompt
308
+ if self.args.diffusion_resolution==512:
309
+ self.diffusion_decoder.pipe.transformer.config.sample_size=16
310
+ elif self.args.diffusion_resolution==1024:
311
+ self.diffusion_decoder.pipe.transformer.config.sample_size=32
312
+ diff_images = self.diffusion_decoder.pipe(
313
+ diff_prompt,
314
+ num_inference_steps=40,
315
+ guidance_scale=4.5,
316
+ gen_image_embeds=gen_image_embeds, #gen_image_embeds,
317
+ control_emd="text",
318
+ ori_inp_way=self.diffusion_decoder.transformer.ori_inp_dit,
319
+ only_t2i="vqconcat",
320
+ img_guidance_scale=1.05,
321
+ height=self.args.diffusion_resolution,
322
+ width=self.args.diffusion_resolution
323
+ ).images
324
+ if return_dict:
325
+ return {"output_images":output_images,"generated_tokens":generated_tokens,"diff_images":diff_images,"answer_prompt":answer_prompt}
326
+ return output_images
327
+
328
+ def generate_images_edit(self, image, prompt, max_new_tokens=256, num_return_sequences=1, cfg_weight=5.0, topk_sample=1000, topp_sample=1.0, temperature=1.0,return_dict=False):
329
+
330
+ vq_image_transform = get_vq_transform(self.args)
331
+ full_image_transform = get_full_transform(self.args)
332
+
333
+ if isinstance(image, str):
334
+ image = Image.open(image).convert('RGB')
335
+ elif isinstance(image, list):
336
+ image = [each_image.convert('RGB') for each_image in image]
337
+ else:
338
+ image = image.convert('RGB')
339
+
340
+ messages = [{'role': 'user', 'content': prompt}]
341
+ text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
342
+ text_token = self.tokenizer.encode(text)
343
+ text_token = torch.tensor(text_token).long().to(self.device)
344
+
345
+ keys = list(self.special_tokens.keys())
346
+ start_token = (torch.ones(1) * self.special_tokens.get(keys[0])).long().to(self.device)
347
+ user_prompt = "<|im_start|>user\n"
348
+ user_prompt_token = self.tokenizer.encode(user_prompt, add_special_tokens=False)
349
+ user_prompt_tensor = torch.tensor(user_prompt_token).long().to(self.device)
350
+ windows = text_token.unfold(0, len(user_prompt_tensor), 1)
351
+ matches = (windows == user_prompt_tensor).all(dim=1)
352
+ image_position = torch.where(matches)[0][0].item() + len(user_prompt_tensor)
353
+
354
+ input_ids = torch.cat((text_token, start_token)).long().to(self.device)
355
+ tokens = torch.zeros((num_return_sequences*2, len(input_ids)), dtype=torch.int).cuda()
356
+ assistant_tokens = input_ids[-4:]
357
+
358
+ for i in range(num_return_sequences*2):
359
+ tokens[i, :] = input_ids
360
+ if i % 2 != 0:
361
+ tokens[i, 1:-1] = self.special_tokens.get(keys[4])
362
+ tokens[i, -4:] = assistant_tokens
363
+
364
+ inputs_embeds = self.llm.model.embed_tokens(tokens).to(self.device)
365
+ position_ids = None
366
+
367
+ if image is not None:
368
+ image_info = preprocess_image_gen(image, self.image_processor, vq_image_transform)
369
+ image_embeds = self.llm.visual(image_info["pixel_values"].to(inputs_embeds.device,self.llm.visual.dtype), grid_thw=image_info["image_grid_thw"].to(inputs_embeds.device))
370
+ image_embeds = image_embeds[None,:].repeat(2, 1, 1).to(inputs_embeds.device, inputs_embeds.dtype)
371
+
372
+ vq_pixel_values = image_info["vq_pixel_values"].to(inputs_embeds.device)
373
+ B = inputs_embeds.size(0)
374
+ if len(vq_pixel_values.shape)==4:
375
+ vq_pixel_values = vq_pixel_values[:,None]
376
+ N = vq_pixel_values.size(1)
377
+ _, _, [_, _, vq_indices] = self.pixel_encoder.encode(vq_pixel_values.flatten(0, 1).bfloat16())
378
+ batch_size = vq_pixel_values.shape[0]
379
+ vq_indices = vq_indices.reshape(batch_size, N, vq_indices.shape[-1])
380
+ vqgan_dec_embeds = self.pixel_encoder.get_codebook_entry(vq_indices)
381
+ vq_embeds = self.pixel_adapter(vqgan_dec_embeds)
382
+ vq_embeds = vq_embeds.repeat(B, 1, 1, 1).to(inputs_embeds.device, inputs_embeds.dtype).flatten(1, 2)
383
+
384
+ vision_start_embeds = self.llm.model.embed_tokens(torch.tensor(self.tokenizer.encode("<|vision_start|>")).long().to(self.device))
385
+ vision_end_embeds = self.llm.model.embed_tokens(torch.tensor(self.tokenizer.encode("<|vision_end|>")).long().to(self.device))
386
+ newline_embeds = self.llm.model.embed_tokens(torch.tensor(self.tokenizer.encode("\n")).long().to(self.device))
387
+ vision_start_embeds = vision_start_embeds.unsqueeze(0).repeat(B, 1, 1)
388
+ vision_end_embeds = vision_end_embeds.unsqueeze(0).repeat(B, 1, 1)
389
+ newline_embeds = newline_embeds.unsqueeze(0).repeat(B, 1, 1)
390
+
391
+ inputs_embeds = torch.cat((inputs_embeds[:, :image_position],
392
+ vision_start_embeds, vq_embeds, vision_end_embeds,
393
+ vision_start_embeds, image_embeds, vision_end_embeds, newline_embeds,
394
+ inputs_embeds[:, image_position:]), dim=1)
395
+
396
+ SPECIAL_VQ_TOKEN = '<|vision_pad|>'
397
+ SPECIAL_VIT_TOKEN = '<|image_pad|>'
398
+ SPECIAL_VQ_TOKEN_ID = self.tokenizer.encode(SPECIAL_VQ_TOKEN)[0]
399
+ SPECIAL_VIT_TOKEN_ID = self.tokenizer.encode(SPECIAL_VIT_TOKEN)[0]
400
+ input_ids_for_position = torch.cat([input_ids[:image_position],
401
+ torch.tensor(self.tokenizer.encode("<|vision_start|>")).to(vq_embeds.device), torch.full((vq_embeds.shape[-2],), SPECIAL_VQ_TOKEN_ID, device=vq_embeds.device), torch.tensor(self.tokenizer.encode("<|vision_end|>")).to(vq_embeds.device),
402
+ torch.tensor(self.tokenizer.encode("<|vision_start|>")).to(vq_embeds.device), torch.full((image_embeds.shape[-2],), SPECIAL_VIT_TOKEN_ID, device=vq_embeds.device), torch.tensor(self.tokenizer.encode("<|vision_end|>")).to(vq_embeds.device), torch.tensor(self.tokenizer.encode("\n")).to(vq_embeds.device),
403
+ input_ids[image_position:],torch.full((vq_embeds.shape[-2],), SPECIAL_VQ_TOKEN_ID, device=vq_embeds.device)], dim=0)
404
+ position_ids, _ = get_rope_index_25(
405
+ self.image_processor.merge_size,
406
+ input_ids_for_position[None],
407
+ image_grid_thw=image_info["image_grid_thw"],
408
+ video_grid_thw=None,
409
+ second_per_grid_ts=None,
410
+ )
411
+
412
+ generated_tokens = torch.zeros((num_return_sequences, max_new_tokens), dtype=torch.int).cuda()
413
+
414
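+ # Autoregressive image-token decoding: the base LLM and the stacked AR head keep separate KV caches,
+ # and after the first step only the embedding of the newly sampled token is fed back in.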
+ for i in range(max_new_tokens):
415
+ if i != 0:
416
+ real_position = position_ids[:,:,outputs.past_key_values.seen_tokens:(outputs.past_key_values.seen_tokens+inputs_embeds.shape[1])].to(inputs_embeds.device)
417
+ else:
418
+ real_position = position_ids[:,:,:inputs_embeds.shape[1]].to(inputs_embeds.device)
419
+ outputs = self.llm.model(
420
+ inputs_embeds=inputs_embeds,
421
+ use_cache=True,
422
+ position_ids = real_position,
423
+ past_key_values=outputs.past_key_values if i != 0 else None,
424
+ output_hidden_states=True)
425
+ last_hidden_states = outputs[0]
426
+
427
+ output_states = self.stacked_ar.model(
428
+ inputs_embeds=last_hidden_states,
429
+ past_key_values=output_states.past_key_values if i != 0 else None,
430
+ output_hidden_states=True,
431
+ position_ids = real_position,
432
+ use_cache=True)
433
+
434
+ last_hidden_states = output_states.hidden_states[-1]
435
+
436
+ logits = self.pixel_output_head(last_hidden_states[:, -1, :])
437
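+ # Rows alternate conditional (even) / unconditional (odd); blend their logits with the CFG weight.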
+ logit_cond = logits[0::2, :]
438
+ logit_uncond = logits[1::2, :]
439
+ logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
440
+ next_token, _ = self.sample(logits, temperature=1.0, top_k=topk_sample, top_p=topp_sample)
441
+ generated_tokens[:, i] = next_token.squeeze(dim=-1)
442
+ next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
443
+
444
+
445
+ vqgan_embeds = self.pixel_encoder.get_codebook_entry(next_token)
446
+ img_embeds = self.pixel_adapter(vqgan_embeds)
447
+ inputs_embeds = img_embeds.unsqueeze(dim=1)
448
+
449
+ latent_size = int(math.sqrt(max_new_tokens))
450
+ output_images = self.pixel_encoder.decode_code(generated_tokens.to(dtype=torch.int), shape=[num_return_sequences, self.pixel_encoder.config.codebook_embed_dim, latent_size, latent_size])
451
+ output_images = output_images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
452
+
453
+ diff_images = None
454
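+ # Optionally refine the result with the diffusion decoder, conditioned on the generated VQ
+ # codebook embeddings (and, for editing, the original input image).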
+ if self.diffusion_decoder is not None:
455
+
456
+ gen_image_embeds = self.pixel_encoder.get_codebook_entry(generated_tokens)
457
+
458
+ if isinstance(image, list):
459
+ processed_img = [full_image_transform(each_image) for each_image in image]
460
+ else:
461
+ processed_img = [full_image_transform(image)]
462
+ if self.args.diffusion_resolution==512:
463
+ self.diffusion_decoder.pipe.transformer.config.sample_size=16
464
+ elif self.args.diffusion_resolution==1024:
465
+ self.diffusion_decoder.pipe.transformer.config.sample_size=32
466
+ diff_images = self.diffusion_decoder.pipe(
467
+ prompt,
468
+ num_inference_steps=50,
469
+ guidance_scale=3.0,
470
+ gen_image_embeds=gen_image_embeds,
471
+ control_emd="text", ori_inp_img=processed_img[0], ori_inp_way="seq",
472
+ only_t2i="vqconcat", img_guidance_scale=1.8, vq_guidance_scale=1, height=self.args.diffusion_resolution, width=self.args.diffusion_resolution
473
+ ).images
474
+ if return_dict:
475
+ return {"output_images": output_images, "generated_tokens": None, "diff_images": diff_images}
476
+ return None
477
+
478
+ def sample(self, logits, temperature: float=1.0, top_k: int=0, top_p: float=1.0, sample_logits=True):
479
+
480
+ logits = logits / max(temperature, 1e-5)
481
+ if top_k > 0 or top_p < 1.0:
482
+ logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
483
+ probs = F.softmax(logits, dim=-1)
484
+ if sample_logits:
485
+ idx = torch.multinomial(probs, num_samples=1)
486
+ else:
487
+ _, idx = torch.topk(probs, k=1, dim=-1)
488
+ return idx, probs
489
+
490
+ def top_k_top_p_filtering(
491
+ self,
492
+ logits,
493
+ top_k: int = 0,
494
+ top_p: float = 1.0,
495
+ filter_value: float = -float("Inf"),
496
+ min_tokens_to_keep: int = 1,
497
+ ):
498
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
499
+ """
500
+ if top_k > 0:
501
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
502
+ # Remove all tokens with a probability less than the last token of the top-k
503
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
504
+ logits[indices_to_remove] = filter_value
505
+
506
+ if top_p < 1.0:
507
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
508
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
509
+
510
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
511
+ sorted_indices_to_remove = cumulative_probs > top_p
512
+ if min_tokens_to_keep > 1:
513
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
514
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
515
+ # Shift the indices to the right to keep also the first token above the threshold
516
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
517
+ sorted_indices_to_remove[..., 0] = 0
518
+
519
+ # scatter sorted tensors to original indexing
520
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
521
+ logits[indices_to_remove] = filter_value
522
+ return logits
523
+
524
+ # For Inference Understand
525
+ def preprocess_image(self, image):
526
+ if image is None:
527
+ return None
528
+ if isinstance(image, str):
529
+ if os.path.exists(image):
530
+ pil_image = Image.open(image).convert('RGB')
531
+ else:
532
+ response = requests.get(image)
533
+ if response.status_code == 200:
534
+ image_bytes = BytesIO(response.content)
535
+ pil_image = Image.open(image_bytes).convert('RGB')
536
+ else:
537
+ raise ValueError(f"Failed to load image from url {image}")
538
+ elif isinstance(image, Image.Image):
539
+ pil_image = image.convert('RGB')
540
+ elif isinstance(image, list):
541
+ return self.preprocess_image(image[0])
542
+ else:
543
+ raise ValueError("Unsupported image type")
544
+
545
+ return pil_image
546
+
547
+ def inference_understand(self, image, question, max_new_tokens=256):
548
+ pil_image = self.preprocess_image(image)
549
+
550
+ messages = [
551
+ {
552
+ "role": "user",
553
+ "content": [
554
+ {
555
+ "type": "image",
556
+ "image": pil_image,
557
+ },
558
+ {"type": "text", "text": question},
559
+ ],
560
+ }
561
+ ]
562
+
563
+ from qwen_vl_utils import process_vision_info
564
+ # Preparation for inference
565
+ text = self.processor.apply_chat_template(
566
+ messages, tokenize=False, add_generation_prompt=True
567
+ )
568
+ image_inputs, video_inputs = process_vision_info(messages)
569
+ inputs = self.processor(
570
+ text=[text],
571
+ images=image_inputs,
572
+ videos=video_inputs,
573
+ padding=True,
574
+ return_tensors="pt",
575
+ )
576
+ inputs = inputs.to(self.llm.device)
577
+
578
+ # Inference: Generation of the output
579
+ generated_ids = self.llm.generate(**inputs, max_new_tokens=max_new_tokens)
580
+ generated_ids_trimmed = [
581
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
582
+ ]
583
+ output_text = self.processor.batch_decode(
584
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
585
+ )
586
+
587
+ return output_text[0] if output_text else ""
star/models/pixel_decoder/lumina2_decoder.py ADDED
@@ -0,0 +1,563 @@
1
+ import torch
2
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, Lumina2Pipeline
3
+ from transformers import AutoTokenizer, Gemma2Model
4
+ import copy
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from diffusers.training_utils import (
8
+ cast_training_params,
9
+ compute_density_for_timestep_sampling,
10
+ compute_loss_weighting_for_sd3,
11
+ free_memory,
12
+ )
13
+ from diffusers.pipelines.lumina2.pipeline_lumina2 import *
14
+
15
+ class Lumina2Decoder(torch.nn.Module):
16
+ def __init__(self, config, args):
17
+ super().__init__()
18
+ self.diffusion_model_path = config.model_path
19
+
20
+ if not hasattr(args, "revision"):
21
+ args.revision = None
22
+ if not hasattr(args, "variant"):
23
+ args.variant = None
24
+
25
+ self.tokenizer_one = AutoTokenizer.from_pretrained(
26
+ self.diffusion_model_path,
27
+ subfolder="tokenizer",
28
+ revision=args.revision,
29
+ )
30
+ self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
31
+ self.diffusion_model_path, subfolder="scheduler"
32
+ )
33
+ self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
34
+ self.text_encoder_one = Gemma2Model.from_pretrained(
35
+ self.diffusion_model_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
36
+ )
37
+ self.text_encoding_pipeline = Lumina2Pipeline.from_pretrained(
38
+ self.diffusion_model_path,
39
+ vae=None,
40
+ transformer=None,
41
+ text_encoder=self.text_encoder_one,
42
+ tokenizer=self.tokenizer_one,
43
+ )
44
+ self.vae = AutoencoderKL.from_pretrained(
45
+ self.diffusion_model_path,
46
+ subfolder="vae",
47
+ revision=args.revision,
48
+ variant=args.variant,
49
+ )
50
+ if args.ori_inp_dit=="seq":
51
+ from star.models.pixel_decoder.transformer_lumina2_seq import Lumina2Transformer2DModel
52
+ elif args.ori_inp_dit=="ref":
53
+ from star.models.pixel_decoder.transformer_lumina2 import Lumina2Transformer2DModel
54
+
55
+ self.transformer = Lumina2Transformer2DModel.from_pretrained(
56
+ self.diffusion_model_path, subfolder="transformer", revision=args.revision, variant=args.variant
57
+ )
58
+
59
+ vq_dim = 512
60
+ patch_size = self.transformer.config.patch_size
61
+ in_channels = vq_dim + self.transformer.config.in_channels # 48 for mask
62
+ out_channels = self.transformer.x_embedder.out_features
63
+
64
+ load_num_channel = self.transformer.config.in_channels * patch_size * patch_size
65
+ self.transformer.register_to_config(in_channels=in_channels)
66
+ transformer = self.transformer
67
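+ # Widen the patch-embedding projection to accept the extra VQ channels: zero-init the new weight,
+ # then copy the pretrained weights into the original channel slots so training starts from the unmodified behaviour.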
+ with torch.no_grad():
68
+ new_proj = nn.Linear(
69
+ in_channels * patch_size * patch_size, out_channels, bias=True
70
+ )
71
+
72
+ new_proj.weight.zero_()
73
+
74
+ new_proj = new_proj.to(transformer.x_embedder.weight.dtype)
75
+ new_proj.weight[:, :load_num_channel].copy_(transformer.x_embedder.weight)
76
+ new_proj.bias.copy_(transformer.x_embedder.bias)
77
+ transformer.x_embedder = new_proj
78
+
79
+ self.ori_inp_dit = args.ori_inp_dit
80
+ if args.ori_inp_dit=="seq":
81
+ refiner_channels = transformer.noise_refiner[-1].dim
82
+ with torch.no_grad():
83
+ vae2cond_proj1 = nn.Linear(refiner_channels, refiner_channels, bias=True)
84
+ vae2cond_act = nn.GELU(approximate='tanh')
85
+ vae2cond_proj2 = nn.Linear(refiner_channels, refiner_channels, bias=False)
86
+ vae2cond_proj2.weight.zero_()
87
+
88
+ ori_inp_refiner = nn.Sequential(
89
+ vae2cond_proj1,
90
+ vae2cond_act,
91
+ vae2cond_proj2
92
+ )
93
+ transformer.ori_inp_refiner = ori_inp_refiner
94
+ transformer.ori_inp_dit = self.ori_inp_dit
95
+ elif args.ori_inp_dit=="ref":
96
+ transformer.initialize_ref_weights()
97
+ transformer.ori_inp_dit = self.ori_inp_dit
98
+
99
+ transformer.requires_grad_(True)
100
+
101
+ if args.grad_ckpt and args.diffusion_resolution==1024:
102
+ transformer.gradient_checkpointing = args.grad_ckpt
103
+ transformer.enable_gradient_checkpointing()
104
+
105
+ self.vae.requires_grad_(False)
106
+ self.vae.to(dtype=torch.float32)
107
+ self.args = args
108
+
109
+ self.pipe = Lumina2InstructPix2PixPipeline.from_pretrained(self.diffusion_model_path,
110
+ transformer=transformer,
111
+ text_encoder=self.text_encoder_one,
112
+ vae=self.vae,
113
+ torch_dtype=torch.bfloat16)
114
+
115
+
116
+ with torch.no_grad():
117
+ _, _, self.uncond_prompt_embeds, self.uncond_prompt_attention_mask = self.text_encoding_pipeline.encode_prompt(
118
+ "",
119
+ max_sequence_length=self.args.max_diff_seq_length,
120
+ )
121
+
122
+ def compute_text_embeddings(self,prompt, text_encoding_pipeline):
123
+ with torch.no_grad():
124
+ prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(
125
+ prompt,
126
+ max_sequence_length=self.args.max_diff_seq_length,
127
+ )
128
+ return prompt_embeds, prompt_attention_mask
129
+
130
+ def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
131
+ sigmas = self.noise_scheduler_copy.sigmas.to(dtype=dtype)
132
+ schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device=timesteps.device)
133
+ timesteps = timesteps
134
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
135
+
136
+ sigma = sigmas[step_indices].flatten()
137
+ while len(sigma.shape) < n_dim:
138
+ sigma = sigma.unsqueeze(-1)
139
+ return sigma
140
+
141
+ def forward(self, batch_gpu,batch, image_embeds):
142
+ args = self.args
143
+ pixel_values = batch_gpu["full_pixel_values"].to(dtype=self.vae.dtype) #aux_image
144
+ data_type = "t2i"
145
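+ # Five-dimensional pixel values carry stacked frames; two frames mark an editing sample
+ # (source first, edited target last), otherwise the batch is plain text-to-image.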
+ if len(pixel_values.shape)==5:
146
+ bs,num_img,c,h,w = pixel_values.shape
147
+ if num_img==2:
148
+ data_type = "edit"
149
+ pixel_values_ori_img = pixel_values[:,0]
150
+ pixel_values = pixel_values[:,-1]
151
+ pixel_values = F.interpolate(pixel_values, size=(self.args.diffusion_resolution, self.args.diffusion_resolution), mode='bilinear',align_corners=False)
152
+ if data_type=="edit" and self.ori_inp_dit!="none":
153
+ pixel_values_ori_img = F.interpolate(pixel_values_ori_img, size=(self.args.diffusion_resolution, self.args.diffusion_resolution), mode='bilinear', align_corners=False)
154
+ prompt = batch["prompts"]
155
+ bs,_,_,_ = pixel_values.shape
156
+ image_prompt_embeds = None
157
+ image_embeds_2d = image_embeds.reshape(bs, 24, 24, image_embeds.shape[-1]).permute(0, 3, 1, 2)
158
+ image_embeds_2d = F.interpolate(image_embeds_2d, size=(args.diffusion_resolution//8, args.diffusion_resolution//8), mode='bilinear', align_corners=False)
159
+
160
+ control_emd = args.control_emd
161
+ prompt_embeds, prompt_attention_mask = self.compute_text_embeddings(prompt, self.text_encoding_pipeline)
162
+ if control_emd=="mix":
163
+ prompt_embeds=torch.cat([prompt_embeds, image_prompt_embeds], dim=1) #use mix
164
+ elif control_emd=="null":
165
+ prompt_embeds = torch.zeros_like(prompt_embeds)
166
+ prompt_attention_mask = torch.ones_like(prompt_attention_mask)
167
+ elif control_emd=="text":
168
+ pass
169
+ elif control_emd=="vit" or control_emd=="vq" or control_emd=="vqvae" or control_emd=="vqconcat" or control_emd=="vqconcatvit":
170
+ prompt_embeds=image_prompt_embeds
171
+
172
+
173
+ latents = self.vae.encode(pixel_values).latent_dist.sample()
174
+ latents = (latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
175
+ latents = latents.to(dtype=image_embeds.dtype)
176
+
177
+ latents_ori_img = torch.zeros_like(latents)
178
+ if data_type=="edit" and self.ori_inp_dit!="none":
179
+ latents_ori_img = self.vae.encode(pixel_values_ori_img).latent_dist.sample()
180
+ latents_ori_img = (latents_ori_img - self.vae.config.shift_factor) * self.vae.config.scaling_factor
181
+ latents_ori_img = latents_ori_img.to(dtype=image_embeds.dtype)
182
+
183
+ # Sample noise that we'll add to the latents
184
+ noise = torch.randn_like(latents)
185
+ bsz = latents.shape[0]
186
+ # Sample a random timestep for each image
187
+ # for weighting schemes where we sample timesteps non-uniformly
188
+ u = compute_density_for_timestep_sampling(
189
+ weighting_scheme=args.weighting_scheme,
190
+ batch_size=bsz,
191
+ logit_mean=args.logit_mean,
192
+ logit_std=args.logit_std,
193
+ mode_scale=args.mode_scale,
194
+ )
195
+
196
+ indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
197
+ timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
198
+
199
+ # Add noise to the latents according to the noise magnitude at each timestep
200
+ # (this is the forward diffusion process)
201
+ sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype).to(device=noise.device)
202
+ #noisy_model_input = (1.0 - sigmas) * noise + sigmas * latents
203
+ noisy_model_input = sigmas * noise + (1-sigmas) * latents
204
+ #noisy_model_input + (1-sigmas)*(latents - noise) = latents
205
+ # Get the additional image embedding for conditioning.
206
+ # Instead of getting a diagonal Gaussian here, we simply take the mode.
207
+ original_image_embeds = image_embeds_2d
208
+
209
+ if args.conditioning_dropout_prob is not None:
210
+ random_p = torch.rand(bsz, device=latents.device)
211
+ # Sample masks for the edit prompts.
212
+ prompt_mask = random_p < 2 * args.uncondition_prob
213
+ prompt_mask = prompt_mask.reshape(bsz, 1, 1)
214
+ # Final text conditioning.
215
+ #prompt_embeds = torch.where(prompt_mask, torch.zeros_like(prompt_embeds), prompt_embeds)
216
+ prompt_embeds = torch.where(prompt_mask, self.uncond_prompt_embeds.repeat(prompt_embeds.shape[0],1,1).to(prompt_embeds.device), prompt_embeds)
217
+ prompt_attention_mask = torch.where(prompt_mask[:,0], self.uncond_prompt_attention_mask.repeat(prompt_embeds.shape[0],1).to(prompt_embeds.device), prompt_attention_mask)
218
+
219
+ # Sample masks for the original images.
220
+ #random_p_vq = torch.rand(bsz, device=latents.device)
221
+ image_mask_dtype = original_image_embeds.dtype
222
+ image_mask = 1 - (
223
+ (random_p <= args.conditioning_dropout_prob).to(image_mask_dtype)
224
+ )
225
+ image_mask = image_mask.reshape(bsz, 1, 1, 1)
226
+
227
+ if data_type=="edit":
228
+ image_mask=0
229
+ # Final image conditioning.
230
+ original_image_embeds = image_mask * original_image_embeds
231
+
232
+ ori_latent_mask = 1 - (
233
+ (random_p >= args.uncondition_prob).to(image_mask_dtype)
234
+ * (random_p < 3 * args.uncondition_prob).to(image_mask_dtype)
235
+ )
236
+ ori_latent_mask = ori_latent_mask.reshape(bsz, 1, 1, 1)
237
+ latents_ori_img = ori_latent_mask * latents_ori_img
238
+
239
+ concatenated_noisy_latents = torch.cat([noisy_model_input, original_image_embeds], dim=1)
240
+
241
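+ # Three ways to inject the source image into the DiT: extra channels ("dim"),
+ # concatenation along the spatial (height) axis ("seq"), or reference hidden states ("ref").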
+ ref_image_hidden_states = None
242
+ if self.ori_inp_dit=="dim":
243
+ concatenated_noisy_latents = torch.cat([concatenated_noisy_latents, latents_ori_img], dim=1)
244
+ elif self.ori_inp_dit=="seq":
245
+ latents_ori_img = torch.cat([latents_ori_img, original_image_embeds], dim=1)
246
+ concatenated_noisy_latents = torch.cat([concatenated_noisy_latents, latents_ori_img], dim=2)
247
+ elif self.ori_inp_dit=="ref":
248
+ latents_ori_img = torch.cat([latents_ori_img, original_image_embeds], dim=1)
249
+ ref_image_hidden_states = latents_ori_img[:,None]
250
+ # Predict the noise residual
251
+ # Scale the timesteps to [0, 1] and reverse them, since the transformer treats t=1 as the clean image.
252
+ timesteps = 1 - timesteps / self.noise_scheduler.config.num_train_timesteps
253
+ model_pred = self.transformer(
254
+ hidden_states=concatenated_noisy_latents,
255
+ timestep=timesteps,
256
+ encoder_hidden_states=prompt_embeds,
257
+ encoder_attention_mask=prompt_attention_mask,
258
+ # ref_image_hidden_states = ref_image_hidden_states,
259
+ return_dict=False,
260
+ )[0]
261
+ if self.ori_inp_dit=="seq":
262
+ model_pred = model_pred[:, :, :args.diffusion_resolution//8, :]
263
+
264
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
265
+ target = latents - noise
266
+ # Conditioning dropout to support classifier-free guidance during inference. For more details
267
+ # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
268
+
269
+ # Concatenate the `original_image_embeds` with the `noisy_latents`.
270
+
271
+ # Get the target for loss depending on the prediction type
272
+ loss = torch.mean(
273
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
274
+ 1,
275
+ )
276
+ loss = loss.mean()
277
+
278
+ loss_value = loss.item()
279
+
280
+ return loss
281
+
282
+
283
+ class Lumina2InstructPix2PixPipeline(Lumina2Pipeline):
284
+
285
+ @torch.no_grad()
286
+ def __call__(
287
+ self,
288
+ prompt: Union[str, List[str]] = None,
289
+ width: Optional[int] = None,
290
+ height: Optional[int] = None,
291
+ num_inference_steps: int = 30,
292
+ guidance_scale: float = 4.0,
293
+ negative_prompt: Union[str, List[str]] = None,
294
+ sigmas: List[float] = None,
295
+ num_images_per_prompt: Optional[int] = 1,
296
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
297
+ latents: Optional[torch.Tensor] = None,
298
+ prompt_embeds: Optional[torch.Tensor] = None,
299
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
300
+ prompt_attention_mask: Optional[torch.Tensor] = None,
301
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
302
+ output_type: Optional[str] = "pil",
303
+ return_dict: bool = True,
304
+ attention_kwargs: Optional[Dict[str, Any]] = None,
305
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
306
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
307
+ system_prompt: Optional[str] = None,
308
+ cfg_trunc_ratio=[0.0,1.0],
309
+ cfg_normalization: bool = False,
310
+ max_sequence_length: int = 256,
311
+ control_emd="text",
312
+ img_cfg_trunc_ratio =[0.0,1.0],
313
+ gen_image_embeds=None,only_t2i="vqconcat",image_prompt_embeds=None,ori_inp_img=None,img_guidance_scale=1.5,vq_guidance_scale=0,ori_inp_way="none",
314
+ ) -> Union[ImagePipelineOutput, Tuple]:
315
+
316
+ height = height or self.default_sample_size * self.vae_scale_factor
317
+ width = width or self.default_sample_size * self.vae_scale_factor
318
+ self._guidance_scale = guidance_scale
319
+ self._attention_kwargs = attention_kwargs
320
+
321
+ num_images_per_prompt = gen_image_embeds.shape[0] if gen_image_embeds is not None else image_prompt_embeds.shape[0]
322
+ # 1. Check inputs. Raise error if not correct
323
+ self.check_inputs(
324
+ prompt,
325
+ height,
326
+ width,
327
+ negative_prompt,
328
+ prompt_embeds=prompt_embeds,
329
+ negative_prompt_embeds=negative_prompt_embeds,
330
+ prompt_attention_mask=prompt_attention_mask,
331
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
332
+ max_sequence_length=max_sequence_length,
333
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
334
+ )
335
+
336
+ # 2. Define call parameters
337
+ if prompt is not None and isinstance(prompt, str):
338
+ batch_size = 1
339
+ elif prompt is not None and isinstance(prompt, list):
340
+ batch_size = len(prompt)
341
+ else:
342
+ batch_size = prompt_embeds.shape[0]
343
+
344
+ device = self._execution_device
345
+
346
+ # 3. Encode input prompt
347
+ (
348
+ prompt_embeds,
349
+ prompt_attention_mask,
350
+ negative_prompt_embeds,
351
+ negative_prompt_attention_mask,
352
+ ) = self.encode_prompt(
353
+ prompt,
354
+ self.do_classifier_free_guidance,
355
+ negative_prompt=negative_prompt,
356
+ num_images_per_prompt=num_images_per_prompt,
357
+ device=device,
358
+ prompt_embeds=prompt_embeds,
359
+ negative_prompt_embeds=negative_prompt_embeds,
360
+ prompt_attention_mask=prompt_attention_mask,
361
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
362
+ max_sequence_length=max_sequence_length,
363
+ system_prompt=system_prompt,
364
+ )
365
+
366
+
367
+ if gen_image_embeds is not None:
368
+ image_embeds_8=gen_image_embeds
369
+
370
+ if control_emd=="text":
371
+ pass
372
+ elif control_emd=="null":
373
+ prompt_embeds = torch.zeros_like(prompt_embeds)
374
+ prompt_attention_mask = torch.zeros_like(prompt_attention_mask)
375
+ negative_prompt_embeds = prompt_embeds
376
+ negative_prompt_attention_mask = prompt_attention_mask
377
+
378
+ if self.do_classifier_free_guidance:
379
+ prompt_embeds = torch.cat([negative_prompt_embeds,negative_prompt_embeds, prompt_embeds], dim=0)
380
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask,negative_prompt_attention_mask, prompt_attention_mask], dim=0)
381
+ # 4. Prepare latents.
382
+ latent_channels = self.vae.config.latent_channels #self.transformer.config.in_channels
383
+ latents = self.prepare_latents(
384
+ batch_size * num_images_per_prompt,
385
+ latent_channels,
386
+ height,
387
+ width,
388
+ prompt_embeds.dtype,
389
+ device,
390
+ generator,
391
+ latents,
392
+ )
393
+
394
+ latents_ori_img = torch.zeros_like(latents)
395
+ if ori_inp_img is not None and ori_inp_way !="none":
396
397
+ ori_inp_img = F.interpolate(ori_inp_img[None].to(latents.device,latents.dtype), size=(height,width), mode='bilinear',align_corners=False)
398
+ latents_ori_img = self.vae.encode(ori_inp_img).latent_dist.sample()
399
+ latents_ori_img = (latents_ori_img- self.vae.config.shift_factor) * self.vae.config.scaling_factor
400
+ latents_ori_img = latents_ori_img.to(dtype=latents.dtype)
401
+ if ori_inp_way !="none":
402
+ negative_latents_ori_img = torch.zeros_like(latents_ori_img).to(prompt_embeds.dtype)
403
+ latents_ori_img = torch.cat([negative_latents_ori_img,latents_ori_img, latents_ori_img], dim=0) if self.do_classifier_free_guidance else latents_ori_img
404
+
405
+ vq_in_edit = False
406
+ if only_t2i==True:
407
+ image_latents = torch.zeros_like(latents)[:,:8]
408
+ elif only_t2i=="vqconcat":
409
+ image_embeds_2d = image_embeds_8.reshape(batch_size* num_images_per_prompt,24,24,image_embeds_8.shape[-1]).permute(0,3,1,2)
410
+ if ori_inp_img is not None and image_embeds_8.mean()!=0:
411
+ vq_in_edit = True
412
+ image_vq_latents = F.interpolate(image_embeds_2d, size=(height//8,width//8), mode='bilinear',align_corners=False).to(latents.device,latents.dtype)
413
+ image_latents = torch.zeros_like(image_vq_latents)
414
+ else:
415
+ image_latents = F.interpolate(image_embeds_2d, size=(height//8,width//8), mode='bilinear',align_corners=False).to(latents.device,latents.dtype)
416
+
417
+ negative_image_latents = torch.zeros_like(image_latents).to(prompt_embeds.dtype)
418
+ image_latents = torch.cat([negative_image_latents,image_latents, image_latents], dim=0) if self.do_classifier_free_guidance else image_latents
419
+
420
+ # 5. Prepare timesteps
421
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
422
+ image_seq_len = latents.shape[1]
423
+ mu = calculate_shift(
424
+ image_seq_len,
425
+ self.scheduler.config.get("base_image_seq_len", 256),
426
+ self.scheduler.config.get("max_image_seq_len", 4096),
427
+ self.scheduler.config.get("base_shift", 0.5),
428
+ self.scheduler.config.get("max_shift", 1.15),
429
+ )
430
+ timesteps, num_inference_steps = retrieve_timesteps(
431
+ self.scheduler,
432
+ num_inference_steps,
433
+ device,
434
+ sigmas=sigmas,
435
+ mu=mu,
436
+ )
437
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
438
+ self._num_timesteps = len(timesteps)
439
+
440
+ self.scheduler.sigmas = self.scheduler.sigmas.to(latents.dtype)  # keep sigmas in the latent dtype to avoid a dtype mismatch in the scheduler step
441
+ # 6. Denoising loop
442
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
443
+ for i, t in enumerate(timesteps):
444
+ # compute whether apply classifier-free truncation on this timestep
445
+ do_classifier_free_truncation = not ((i + 1) / num_inference_steps > cfg_trunc_ratio[0] and (i + 1) / num_inference_steps < cfg_trunc_ratio[1])
446
+ img_do_classifier_free_truncation = not ((i + 1) / num_inference_steps > img_cfg_trunc_ratio[0] and (i + 1) / num_inference_steps < img_cfg_trunc_ratio[1])
447
+
448
+ # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
449
+ current_timestep = 1 - t / self.scheduler.config.num_train_timesteps
450
+
451
+ latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents
452
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
453
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
454
+
455
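+ # InstructPix2Pix-style conditioning: concatenate the resized VQ feature map with the noisy latents along the channel dimension.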
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
456
+
457
+ ref_image_hidden_states = None
458
+ if ori_inp_way=="seq":
459
+ latents_ori_img_cat = torch.cat([latents_ori_img, image_latents], dim=1)
460
+ latent_model_input = torch.cat([latent_model_input, latents_ori_img_cat], dim=2)
461
+ elif ori_inp_way=="ref":
462
+ latents_ori_img_cat = torch.cat([latents_ori_img, image_latents], dim=1)
463
+ ref_image_hidden_states = latents_ori_img_cat[:,None]
464
+
465
+ if ori_inp_way=="ref":
466
+ noise_pred = self.transformer(
467
+ hidden_states=latent_model_input,
468
+ timestep=current_timestep,
469
+ encoder_hidden_states=prompt_embeds,
470
+ encoder_attention_mask=prompt_attention_mask,
471
+ return_dict=False,ref_image_hidden_states=ref_image_hidden_states,
472
+ attention_kwargs=self.attention_kwargs,
473
+ )[0]
474
+ else:
475
+ noise_pred = self.transformer(
476
+ hidden_states=latent_model_input,
477
+ timestep=current_timestep,
478
+ encoder_hidden_states=prompt_embeds,
479
+ encoder_attention_mask=prompt_attention_mask,
480
+ return_dict=False,
481
+ attention_kwargs=self.attention_kwargs,
482
+ )[0]
483
+ if ori_inp_way=="seq":
484
+ noise_pred = noise_pred[:,:,:height//8,:]
485
+
486
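+ # For editing with non-zero VQ features, run an extra forward pass conditioned only on the VQ map;
+ # its prediction is blended in below with vq_guidance_scale.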
+ if vq_in_edit:
487
+ latent_model_vq_input = torch.cat([latents, image_vq_latents], dim=1)
488
+ if ori_inp_way=="seq":
489
+ latents_ori_img_cat_vq = torch.cat([torch.zeros_like(latents), image_vq_latents], dim=1)
490
+ latent_model_vq_input = torch.cat([latent_model_vq_input, latents_ori_img_cat_vq], dim=2)
491
+
492
+ noise_vq_pred = self.transformer(
493
+ hidden_states=latent_model_vq_input,
494
+ timestep=current_timestep[-1:],
495
+ encoder_hidden_states=prompt_embeds[-1:],
496
+ encoder_attention_mask=prompt_attention_mask[-1:],
497
+ return_dict=False,
498
+ attention_kwargs=self.attention_kwargs,
499
+ )[0]
500
+ if ori_inp_way=="seq":
501
+ noise_vq_pred = noise_vq_pred[:,:,:height//8,:]
502
+ # perform normalization-based guidance scale on a truncated timestep interval
503
+ if self.do_classifier_free_guidance:
504
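+ # Three predictions per step (unconditional / image-conditioned / text+image-conditioned);
+ # text and image guidance use separate scales and truncation intervals.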
+ noise_pred_uncond,noise_pred_img, noise_pred_text = noise_pred.chunk(3)
505
+ if not do_classifier_free_truncation and not img_do_classifier_free_truncation:
506
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_img)+ img_guidance_scale * (noise_pred_img - noise_pred_uncond)
507
+ elif not do_classifier_free_truncation and img_do_classifier_free_truncation:
508
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_img)+ 1 * (noise_pred_img - noise_pred_uncond)
509
+ elif do_classifier_free_truncation and not img_do_classifier_free_truncation:
510
+ noise_pred = noise_pred_uncond + 1 * (noise_pred_text - noise_pred_img)+ img_guidance_scale * (noise_pred_img - noise_pred_uncond)
511
+ else:
512
+ noise_pred = noise_pred_text
513
+ if vq_in_edit:
514
+ noise_pred = noise_pred +vq_guidance_scale*(noise_vq_pred-noise_pred_uncond)
515
+ # apply normalization after classifier-free guidance
516
+ if cfg_normalization:
517
+ cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True)
518
+ noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
519
+ noise_pred = noise_pred * (cond_norm / noise_norm)
520
+ else:
521
+ noise_pred = noise_pred
522
+
523
+ # compute the previous noisy sample x_t -> x_t-1
524
+ latents_dtype = latents.dtype
525
+ noise_pred = -noise_pred
526
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
527
+
528
+ if latents.dtype != latents_dtype:
529
+ if torch.backends.mps.is_available():
530
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
531
+ latents = latents.to(latents_dtype)
532
+
533
+ if callback_on_step_end is not None:
534
+ callback_kwargs = {}
535
+ for k in callback_on_step_end_tensor_inputs:
536
+ callback_kwargs[k] = locals()[k]
537
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
538
+
539
+ latents = callback_outputs.pop("latents", latents)
540
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
541
+
542
+ # call the callback, if provided
543
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
544
+ progress_bar.update()
545
+
546
+ if XLA_AVAILABLE:
547
+ xm.mark_step()
548
+
549
+ if not output_type == "latent":
550
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
551
+ image = self.vae.decode(latents, return_dict=False)[0]
552
+ image = self.image_processor.postprocess(image, output_type=output_type)
553
+ else:
554
+ image = latents
555
+
556
+ # Offload all models
557
+ self.maybe_free_model_hooks()
558
+
559
+ if not return_dict:
560
+ return (image,)
561
+
562
+ return ImagePipelineOutput(images=image)
563
+
star/models/pixel_decoder/transformer_lumina2.py ADDED
@@ -0,0 +1,770 @@
1
+ # Copyright 2025 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from einops import rearrange
22
+ from diffusers.models.transformers.transformer_lumina2 import *
23
+ from einops import repeat
24
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
25
+ import itertools
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
31
+ def __init__(
32
+ self,
33
+ hidden_size: int = 4096,
34
+ cap_feat_dim: int = 2048,
35
+ frequency_embedding_size: int = 256,
36
+ norm_eps: float = 1e-5,
37
+ ) -> None:
38
+ super().__init__()
39
+
40
+ self.time_proj = Timesteps(
41
+ num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
42
+ )
43
+
44
+ self.timestep_embedder = TimestepEmbedding(
45
+ in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
46
+ )
47
+
48
+ self.caption_embedder = nn.Sequential(
49
+ RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True)
50
+ )
51
+
52
+ def forward(
53
+ self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor
54
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
55
+ timestep_proj = self.time_proj(timestep).type_as(hidden_states[0])
56
+ time_embed = self.timestep_embedder(timestep_proj)
57
+ caption_embed = self.caption_embedder(encoder_hidden_states)
58
+ return time_embed, caption_embed
59
+
60
+
61
+ class Lumina2AttnProcessor2_0:
62
+ r"""
63
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
64
+ used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors.
65
+ """
66
+
67
+ def __init__(self):
68
+ if not hasattr(F, "scaled_dot_product_attention"):
69
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
70
+
71
+ def __call__(
72
+ self,
73
+ attn: Attention,
74
+ hidden_states: torch.Tensor,
75
+ encoder_hidden_states: torch.Tensor,
76
+ attention_mask: Optional[torch.Tensor] = None,
77
+ image_rotary_emb: Optional[torch.Tensor] = None,
78
+ base_sequence_length: Optional[int] = None,
79
+ ) -> torch.Tensor:
80
+ batch_size, sequence_length, _ = hidden_states.shape
81
+
82
+ # Get Query-Key-Value Pair
83
+ query = attn.to_q(hidden_states)
84
+ key = attn.to_k(encoder_hidden_states)
85
+ value = attn.to_v(encoder_hidden_states)
86
+
87
+ query_dim = query.shape[-1]
88
+ inner_dim = key.shape[-1]
89
+ head_dim = query_dim // attn.heads
90
+ dtype = query.dtype
91
+
92
+ # Get key-value heads
93
+ kv_heads = inner_dim // head_dim
94
+
95
+ query = query.view(batch_size, -1, attn.heads, head_dim)
96
+ key = key.view(batch_size, -1, kv_heads, head_dim)
97
+ value = value.view(batch_size, -1, kv_heads, head_dim)
98
+
99
+ # Apply Query-Key Norm if needed
100
+ if attn.norm_q is not None:
101
+ query = attn.norm_q(query)
102
+ if attn.norm_k is not None:
103
+ key = attn.norm_k(key)
104
+
105
+ # Apply RoPE if needed
106
+ if image_rotary_emb is not None:
107
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
108
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
109
+
110
+ query, key = query.to(dtype), key.to(dtype)
111
+
112
+ # Apply proportional attention if true
113
+ if base_sequence_length is not None:
114
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
115
+ else:
116
+ softmax_scale = attn.scale
117
+
118
+ # perform Grouped-Query Attention (GQA)
119
+ n_rep = attn.heads // kv_heads
120
+ if n_rep >= 1:
121
+ key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
122
+ value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
123
+
124
+ # scaled_dot_product_attention expects attention_mask shape to be
125
+ # (batch, heads, source_length, target_length)
126
+ if attention_mask is not None:
127
+ attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
128
+
129
+ query = query.transpose(1, 2)
130
+ key = key.transpose(1, 2)
131
+ value = value.transpose(1, 2)
132
+
133
+ hidden_states = F.scaled_dot_product_attention(
134
+ query, key, value, attn_mask=attention_mask, scale=softmax_scale
135
+ )
136
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
137
+ hidden_states = hidden_states.type_as(query)
138
+
139
+ # linear proj
140
+ hidden_states = attn.to_out[0](hidden_states)
141
+ hidden_states = attn.to_out[1](hidden_states)
142
+ return hidden_states
143
+
144
+
145
+ class Lumina2TransformerBlock(nn.Module):
146
+ def __init__(
147
+ self,
148
+ dim: int,
149
+ num_attention_heads: int,
150
+ num_kv_heads: int,
151
+ multiple_of: int,
152
+ ffn_dim_multiplier: float,
153
+ norm_eps: float,
154
+ modulation: bool = True,
155
+ ) -> None:
156
+ super().__init__()
157
+ self.head_dim = dim // num_attention_heads
158
+ self.dim = dim
159
+ self.modulation = modulation
160
+
161
+ self.attn = Attention(
162
+ query_dim=dim,
163
+ cross_attention_dim=None,
164
+ dim_head=dim // num_attention_heads,
165
+ qk_norm="rms_norm",
166
+ heads=num_attention_heads,
167
+ kv_heads=num_kv_heads,
168
+ eps=1e-5,
169
+ bias=False,
170
+ out_bias=False,
171
+ processor=Lumina2AttnProcessor2_0(),
172
+ )
173
+
174
+ self.feed_forward = LuminaFeedForward(
175
+ dim=dim,
176
+ inner_dim=4 * dim,
177
+ multiple_of=multiple_of,
178
+ ffn_dim_multiplier=ffn_dim_multiplier,
179
+ )
180
+
181
+ if modulation:
182
+ self.norm1 = LuminaRMSNormZero(
183
+ embedding_dim=dim,
184
+ norm_eps=norm_eps,
185
+ norm_elementwise_affine=True,
186
+ )
187
+ else:
188
+ self.norm1 = RMSNorm(dim, eps=norm_eps)
189
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
190
+
191
+ self.norm2 = RMSNorm(dim, eps=norm_eps)
192
+ self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
193
+
194
+ def forward(
195
+ self,
196
+ hidden_states: torch.Tensor,
197
+ attention_mask: torch.Tensor,
198
+ image_rotary_emb: torch.Tensor,
199
+ temb: Optional[torch.Tensor] = None,
200
+ ) -> torch.Tensor:
201
+ if self.modulation:
202
+ norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
203
+ attn_output = self.attn(
204
+ hidden_states=norm_hidden_states,
205
+ encoder_hidden_states=norm_hidden_states,
206
+ attention_mask=attention_mask,
207
+ image_rotary_emb=image_rotary_emb,
208
+ )
209
+ hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
210
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
211
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
212
+ else:
213
+ norm_hidden_states = self.norm1(hidden_states)
214
+ attn_output = self.attn(
215
+ hidden_states=norm_hidden_states,
216
+ encoder_hidden_states=norm_hidden_states,
217
+ attention_mask=attention_mask,
218
+ image_rotary_emb=image_rotary_emb,
219
+ )
220
+ hidden_states = hidden_states + self.norm2(attn_output)
221
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
222
+ hidden_states = hidden_states + self.ffn_norm2(mlp_output)
223
+
224
+ return hidden_states
225
+
226
+
227
+ class Lumina2RotaryPosEmbed(nn.Module):
228
+ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2):
229
+ super().__init__()
230
+ self.theta = theta
231
+ self.axes_dim = axes_dim
232
+ self.axes_lens = axes_lens
233
+ self.patch_size = patch_size
234
+
235
+ self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta)
236
+
237
+ def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
238
+ freqs_cis = []
239
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
240
+ for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
241
+ emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
242
+ freqs_cis.append(emb)
243
+ return freqs_cis
244
+
245
+ def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
246
+ device = ids.device
247
+ if ids.device.type == "mps":
248
+ ids = ids.to("cpu")
249
+
250
+ result = []
251
+ for i in range(len(self.axes_dim)):
252
+ freqs = self.freqs_cis[i].to(ids.device)
253
+ index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
254
+ result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
255
+ return torch.cat(result, dim=-1).to(device)
256
+
257
+ def forward(
258
+ self,
259
+ attention_mask,
260
+ l_effective_ref_img_len,
261
+ l_effective_img_len,
262
+ ref_img_sizes,
263
+ img_sizes,
264
+ device
265
+ ):
266
+
267
+ batch_size = len(attention_mask)
268
+ p = self.patch_size
269
+
270
+ encoder_seq_len = attention_mask.shape[1]
271
+ l_effective_cap_len = attention_mask.sum(dim=1).tolist()
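+ # Each sample's sequence is laid out as caption tokens, then all reference-image tokens,
+ # then target-image tokens; the position ids below follow the same order.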
272
+
273
+ seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
274
+
275
+ max_seq_len = max(seq_lengths)
276
+ max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
277
+ max_img_len = max(l_effective_img_len)
278
+
279
+ # Create position IDs
280
+ position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
281
+
282
+ for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
283
+ # add text position ids
284
+ position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
285
+
286
+ pe_shift = cap_seq_len
287
+ pe_shift_len = cap_seq_len
288
+
289
+ if ref_img_sizes[i] is not None:
290
+ for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
291
+ H, W = ref_img_size
292
+ ref_H_tokens, ref_W_tokens = H // p, W // p
293
+ assert ref_H_tokens * ref_W_tokens == ref_img_len
294
+ # add image position ids
295
+
296
+ row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
297
+ col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
298
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
299
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
300
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
301
+
302
+ pe_shift += max(ref_H_tokens, ref_W_tokens)
303
+ pe_shift_len += ref_img_len
304
+
305
+ H, W = img_sizes[i]
306
+ H_tokens, W_tokens = H // p, W // p
307
+ assert H_tokens * W_tokens == l_effective_img_len[i]
308
+
309
+ row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
310
+ col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
311
+
312
+ assert pe_shift_len + l_effective_img_len[i] == seq_len
313
+ position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
314
+ position_ids[i, pe_shift_len: seq_len, 1] = row_ids
315
+ position_ids[i, pe_shift_len: seq_len, 2] = col_ids
316
+
317
+ # Get combined rotary embeddings
318
+ freqs_cis = self._get_freqs_cis(position_ids)
319
+
320
+ # create separate rotary embeddings for captions and images
321
+ cap_freqs_cis = torch.zeros(
322
+ batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
323
+ )
324
+ ref_img_freqs_cis = torch.zeros(
325
+ batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
326
+ )
327
+ img_freqs_cis = torch.zeros(
328
+ batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
329
+ )
330
+
331
+ for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
332
+ cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
333
+ ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
334
+ img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
335
+
336
+ return (
337
+ cap_freqs_cis,
338
+ ref_img_freqs_cis,
339
+ img_freqs_cis,
340
+ freqs_cis,
341
+ l_effective_cap_len,
342
+ seq_lengths,
343
+ )
344
+
345
+
346
+ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
347
+ r"""
348
+ Lumina2NextDiT: Diffusion model with a Transformer backbone.
349
+
350
+ Parameters:
351
+ sample_size (`int`): The width of the latent images. This is fixed during training since
352
+ it is used to learn a number of position embeddings.
353
+ patch_size (`int`, *optional*, defaults to 2):
354
+ The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
355
+ in_channels (`int`, *optional*, defaults to 4):
356
+ The number of input channels for the model. Typically, this matches the number of channels in the input
357
+ images.
358
+ hidden_size (`int`, *optional*, defaults to 4096):
359
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
360
+ hidden representations.
361
+ num_layers (`int`, *optional*, default to 32):
362
+ The number of layers in the model. This defines the depth of the neural network.
363
+ num_attention_heads (`int`, *optional*, defaults to 32):
364
+ The number of attention heads in each attention layer. This parameter specifies how many separate attention
365
+ mechanisms are used.
366
+ num_kv_heads (`int`, *optional*, defaults to 8):
367
+ The number of key-value heads in the attention mechanism, if different from the number of attention heads.
368
+ If None, it defaults to num_attention_heads.
369
+ multiple_of (`int`, *optional*, defaults to 256):
370
+ A factor that the hidden size should be a multiple of. This can help optimize certain hardware
371
+ configurations.
372
+ ffn_dim_multiplier (`float`, *optional*):
373
+ A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
374
+ the model configuration.
375
+ norm_eps (`float`, *optional*, defaults to 1e-5):
376
+ A small value added to the denominator for numerical stability in normalization layers.
377
+ scaling_factor (`float`, *optional*, defaults to 1.0):
378
+ A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
379
+ overall scale of the model's operations.
380
+ """
381
+
382
+ _supports_gradient_checkpointing = True
383
+ _no_split_modules = ["Lumina2TransformerBlock"]
384
+ _skip_layerwise_casting_patterns = ["x_embedder", "norm"]
385
+
386
+ @register_to_config
387
+ def __init__(
388
+ self,
389
+ sample_size: int = 128,
390
+ patch_size: int = 2,
391
+ in_channels: int = 16,
392
+ out_channels: Optional[int] = None,
393
+ hidden_size: int = 2304,
394
+ num_layers: int = 26,
395
+ num_refiner_layers: int = 2,
396
+ num_attention_heads: int = 24,
397
+ num_kv_heads: int = 8,
398
+ multiple_of: int = 256,
399
+ ffn_dim_multiplier: Optional[float] = None,
400
+ norm_eps: float = 1e-5,
401
+ scaling_factor: float = 1.0,
402
+ axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
403
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
404
+ cap_feat_dim: int = 1024,
405
+ ) -> None:
406
+ super().__init__()
407
+ self.out_channels = out_channels or in_channels
408
+
409
+ # 1. Positional, patch & conditional embeddings
410
+ self.rope_embedder = Lumina2RotaryPosEmbed(
411
+ theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size
412
+ )
413
+
414
+ self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size)
415
+
416
+ self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
417
+ hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps
418
+ )
419
+
420
+ # 2. Noise and context refinement blocks
421
+ self.noise_refiner = nn.ModuleList(
422
+ [
423
+ Lumina2TransformerBlock(
424
+ hidden_size,
425
+ num_attention_heads,
426
+ num_kv_heads,
427
+ multiple_of,
428
+ ffn_dim_multiplier,
429
+ norm_eps,
430
+ modulation=True,
431
+ )
432
+ for _ in range(num_refiner_layers)
433
+ ]
434
+ )
435
+
436
+ self.context_refiner = nn.ModuleList(
437
+ [
438
+ Lumina2TransformerBlock(
439
+ hidden_size,
440
+ num_attention_heads,
441
+ num_kv_heads,
442
+ multiple_of,
443
+ ffn_dim_multiplier,
444
+ norm_eps,
445
+ modulation=False,
446
+ )
447
+ for _ in range(num_refiner_layers)
448
+ ]
449
+ )
450
+
451
+ # 3. Transformer blocks
452
+ self.layers = nn.ModuleList(
453
+ [
454
+ Lumina2TransformerBlock(
455
+ hidden_size,
456
+ num_attention_heads,
457
+ num_kv_heads,
458
+ multiple_of,
459
+ ffn_dim_multiplier,
460
+ norm_eps,
461
+ modulation=True,
462
+ )
463
+ for _ in range(num_layers)
464
+ ]
465
+ )
466
+
467
+ # 4. Output norm & projection
468
+ self.norm_out = LuminaLayerNormContinuous(
469
+ embedding_dim=hidden_size,
470
+ conditioning_embedding_dim=min(hidden_size, 1024),
471
+ elementwise_affine=False,
472
+ eps=1e-6,
473
+ bias=True,
474
+ out_dim=patch_size * patch_size * self.out_channels,
475
+ )
476
+
477
+ self.gradient_checkpointing = False
478
+
479
+ self.args_dict = {"patch_size":patch_size,"in_channels":in_channels,"hidden_size":hidden_size,
480
+ "num_attention_heads":num_attention_heads,"num_kv_heads":num_kv_heads,
481
+ "multiple_of":multiple_of,"ffn_dim_multiplier":ffn_dim_multiplier,
482
+ "norm_eps":norm_eps,"num_refiner_layers":num_refiner_layers}
483
+
484
+ def initialize_ref_weights(self) -> None:
485
+ """
486
+ Initialize the weights of the model.
487
+
488
+ Uses Xavier uniform initialization for linear layers.
489
+ """
490
+ patch_size, in_channels, hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, num_refiner_layers = \
491
+ (self.args_dict[k] for k in ["patch_size","in_channels","hidden_size","num_attention_heads","num_kv_heads",
492
+ "multiple_of","ffn_dim_multiplier","norm_eps","num_refiner_layers"])
493
+ with torch.no_grad():
494
+ self.ref_image_patch_embedder = nn.Linear(
495
+ in_features=self.x_embedder.in_features,
496
+ out_features=hidden_size,
497
+ )
498
+ self.ref_image_refiner = nn.ModuleList([
499
+ Lumina2TransformerBlock(
500
+ hidden_size,
501
+ num_attention_heads,
502
+ num_kv_heads,
503
+ multiple_of,
504
+ ffn_dim_multiplier,
505
+ norm_eps,
506
+ modulation=True
507
+ )
508
+ for _ in range(num_refiner_layers)
509
+ ])
510
+ nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
511
+ nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
512
+
513
+ # Add learnable embeddings to distinguish different images
514
+ self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images
515
+ nn.init.normal_(self.image_index_embedding, std=0.02)
516
+
517
+ def img_patch_embed_and_refine(
518
+ self,
519
+ hidden_states,
520
+ ref_image_hidden_states,
521
+ padded_img_mask,
522
+ padded_ref_img_mask,
523
+ noise_rotary_emb,
524
+ ref_img_rotary_emb,
525
+ l_effective_ref_img_len,
526
+ l_effective_img_len,
527
+ temb
528
+ ):
529
+ batch_size = len(hidden_states)
530
+ max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])
531
+
532
+ hidden_states = self.x_embedder(hidden_states)
533
+ ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
534
+
535
+ for i in range(batch_size):
536
+ shift = 0
537
+ for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
538
+ ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
539
+ shift += ref_img_len
540
+
541
+ for layer in self.noise_refiner:
542
+ hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
543
+
544
+ flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
545
+ num_ref_images = len(flat_l_effective_ref_img_len)
546
+ max_ref_img_len = max(flat_l_effective_ref_img_len)
547
+
548
+ batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
549
+ batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
550
+ batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
551
+ batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
552
+
553
+ # sequence of ref imgs to batch
554
+ idx = 0
555
+ for i in range(batch_size):
556
+ shift = 0
557
+ for ref_img_len in l_effective_ref_img_len[i]:
558
+ batch_ref_img_mask[idx, :ref_img_len] = True
559
+ batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
560
+ batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
561
+ batch_temb[idx] = temb[i]
562
+ shift += ref_img_len
563
+ idx += 1
564
+
565
+ # refine ref imgs separately
566
+ for layer in self.ref_image_refiner:
567
+ batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)
568
+
569
+ # batch of ref imgs to sequence
570
+ idx = 0
571
+ for i in range(batch_size):
572
+ shift = 0
573
+ for ref_img_len in l_effective_ref_img_len[i]:
574
+ ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
575
+ shift += ref_img_len
576
+ idx += 1
577
+
578
+ combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
579
+ for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
580
+ combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
581
+ combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]
582
+
583
+ return combined_img_hidden_states
584
+
585
+ def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
586
+ batch_size = len(hidden_states)
587
+ p = self.config.patch_size
588
+ device = hidden_states[0].device
589
+
590
+ img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
591
+ l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
592
+
593
+ if ref_image_hidden_states is not None:
594
+ ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states]
595
+ l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
596
+ else:
597
+ ref_img_sizes = [None for _ in range(batch_size)]
598
+ l_effective_ref_img_len = [[0] for _ in range(batch_size)]
599
+
600
+ max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
601
+ max_img_len = max(l_effective_img_len)
602
+
603
+ # ref image patch embeddings
604
+ flat_ref_img_hidden_states = []
605
+ for i in range(batch_size):
606
+ if ref_img_sizes[i] is not None:
607
+ imgs = []
608
+ for ref_img in ref_image_hidden_states[i]:
609
+ C, H, W = ref_img.size()
610
+ ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
611
+ imgs.append(ref_img)
612
+
613
+ img = torch.cat(imgs, dim=0)
614
+ flat_ref_img_hidden_states.append(img)
615
+ else:
616
+ flat_ref_img_hidden_states.append(None)
617
+
618
+ # image patch embeddings
619
+ flat_hidden_states = []
620
+ for i in range(batch_size):
621
+ img = hidden_states[i]
622
+ C, H, W = img.size()
623
+
624
+ img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
625
+ flat_hidden_states.append(img)
626
+
627
+ padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
628
+ padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
629
+ for i in range(batch_size):
630
+ if ref_img_sizes[i] is not None:
631
+ padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
632
+ padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True
633
+
634
+ padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
635
+ padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
636
+ for i in range(batch_size):
637
+ padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
638
+ padded_img_mask[i, :l_effective_img_len[i]] = True
639
+
640
+ return (
641
+ padded_hidden_states,
642
+ padded_ref_img_hidden_states,
643
+ padded_img_mask,
644
+ padded_ref_img_mask,
645
+ l_effective_ref_img_len,
646
+ l_effective_img_len,
647
+ ref_img_sizes,
648
+ img_sizes,
649
+ )
650
+
651
+ def forward(
652
+ self,
653
+ hidden_states: torch.Tensor,
654
+ timestep: torch.Tensor,
655
+ encoder_hidden_states: torch.Tensor,
656
+ encoder_attention_mask: torch.Tensor,
657
+ ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
658
+ attention_kwargs: Optional[Dict[str, Any]] = None,
659
+ return_dict: bool = True,
660
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
661
+ if attention_kwargs is not None:
662
+ attention_kwargs = attention_kwargs.copy()
663
+ lora_scale = attention_kwargs.pop("scale", 1.0)
664
+ else:
665
+ lora_scale = 1.0
666
+
667
+ if USE_PEFT_BACKEND:
668
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
669
+ scale_lora_layers(self, lora_scale)
670
+ else:
671
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
672
+ logger.warning(
673
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
674
+ )
675
+
676
+ # 1. Condition, positional & patch embedding
677
+ batch_size = len(hidden_states)
678
+ is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)
679
+
680
+ if is_hidden_states_tensor:
681
+ assert hidden_states.ndim == 4
682
+ hidden_states = [_hidden_states for _hidden_states in hidden_states]
683
+
684
+ device = hidden_states[0].device
685
+
686
+ temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
687
+
688
+ (
689
+ hidden_states,
690
+ ref_image_hidden_states,
691
+ img_mask,
692
+ ref_img_mask,
693
+ l_effective_ref_img_len,
694
+ l_effective_img_len,
695
+ ref_img_sizes,
696
+ img_sizes,
697
+ ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
698
+
699
+ (
700
+ context_rotary_emb,
701
+ ref_img_rotary_emb,
702
+ noise_rotary_emb,
703
+ rotary_emb,
704
+ encoder_seq_lengths,
705
+ seq_lengths,
706
+ ) = self.rope_embedder(
707
+ encoder_attention_mask,
708
+ l_effective_ref_img_len,
709
+ l_effective_img_len,
710
+ ref_img_sizes,
711
+ img_sizes,
712
+ device,
713
+ )
714
+
715
+ # 2. Context & noise refinement
716
+ for layer in self.context_refiner:
717
+ encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)
718
+
719
+ combined_img_hidden_states = self.img_patch_embed_and_refine(
720
+ hidden_states,
721
+ ref_image_hidden_states,
722
+ img_mask,
723
+ ref_img_mask,
724
+ noise_rotary_emb,
725
+ ref_img_rotary_emb,
726
+ l_effective_ref_img_len,
727
+ l_effective_img_len,
728
+ temb,
729
+ )
730
+
731
+ # 3. Joint Transformer blocks
732
+ max_seq_len = max(seq_lengths)
733
+ use_mask = len(set(seq_lengths)) > 1
734
+
735
+ attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
736
+ joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
737
+ for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
738
+ attention_mask[i, :seq_len] = True
739
+ joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
740
+ joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]
741
+
742
+ hidden_states = joint_hidden_states
743
+
744
+ for layer in self.layers:
745
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
746
+ hidden_states = self._gradient_checkpointing_func(
747
+ layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb
748
+ )
749
+ else:
750
+ hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb)
751
+
752
+ # 4. Output norm & projection
753
+ hidden_states = self.norm_out(hidden_states, temb)
754
+
755
+ # 5. Unpatchify
756
+ p = self.config.patch_size
757
+ output = []
758
+ for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
759
+ height, width = img_size
760
+ output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p))
761
+ if is_hidden_states_tensor:
762
+ output = torch.stack(output, dim=0)
763
+
764
+ if USE_PEFT_BACKEND:
765
+ # remove `lora_scale` from each PEFT layer
766
+ unscale_lora_layers(self, lora_scale)
767
+
768
+ if not return_dict:
769
+ return (output,)
770
+ return Transformer2DModelOutput(sample=output)
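A minimal standalone sketch (not part of the committed files) of the einops patchify/unpatchify convention that `flat_and_pad_to_seq` and the final unpatchify step above rely on; the channel count, spatial size, and patch size below are illustrative assumptions.
import torch
from einops import rearrange

p = 2                                              # patch_size
latent = torch.randn(16, 32, 32)                   # (C, H, W) with H and W divisible by p

# (C, H, W) -> (num_patches, p*p*C): one flattened patch per token
tokens = rearrange(latent, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
print(tokens.shape)                                # torch.Size([256, 64])

# inverse mapping, as used in the unpatchify step of forward()
restored = rearrange(tokens, '(h w) (p1 p2 c) -> c (h p1) (w p2)',
                     h=latent.shape[1] // p, w=latent.shape[2] // p, p1=p, p2=p)
assert torch.equal(latent, restored)               # the round trip is lossless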
star/models/pixel_decoder/transformer_lumina2_seq.py ADDED
@@ -0,0 +1,551 @@
1
+ # Copyright 2025 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from einops import rearrange
22
+ from diffusers.models.transformers.transformer_lumina2 import *
23
+ from einops import repeat
24
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
25
+ import itertools
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
31
+ def __init__(
32
+ self,
33
+ hidden_size: int = 4096,
34
+ cap_feat_dim: int = 2048,
35
+ frequency_embedding_size: int = 256,
36
+ norm_eps: float = 1e-5,
37
+ ) -> None:
38
+ super().__init__()
39
+
40
+ self.time_proj = Timesteps(
41
+ num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
42
+ )
43
+
44
+ self.timestep_embedder = TimestepEmbedding(
45
+ in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
46
+ )
47
+
48
+ self.caption_embedder = nn.Sequential(
49
+ RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True)
50
+ )
51
+
52
+ def forward(
53
+ self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor
54
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
55
+ timestep_proj = self.time_proj(timestep).type_as(hidden_states)
56
+ time_embed = self.timestep_embedder(timestep_proj)
57
+ caption_embed = self.caption_embedder(encoder_hidden_states)
58
+ return time_embed, caption_embed
59
+
60
+
61
+ class Lumina2AttnProcessor2_0:
62
+ r"""
63
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
64
+ used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors.
65
+ """
66
+
67
+ def __init__(self):
68
+ if not hasattr(F, "scaled_dot_product_attention"):
69
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
70
+
71
+ def __call__(
72
+ self,
73
+ attn: Attention,
74
+ hidden_states: torch.Tensor,
75
+ encoder_hidden_states: torch.Tensor,
76
+ attention_mask: Optional[torch.Tensor] = None,
77
+ image_rotary_emb: Optional[torch.Tensor] = None,
78
+ base_sequence_length: Optional[int] = None,
79
+ ) -> torch.Tensor:
80
+ batch_size, sequence_length, _ = hidden_states.shape
81
+
82
+ # Get Query-Key-Value Pair
83
+ query = attn.to_q(hidden_states)
84
+ key = attn.to_k(encoder_hidden_states)
85
+ value = attn.to_v(encoder_hidden_states)
86
+
87
+ query_dim = query.shape[-1]
88
+ inner_dim = key.shape[-1]
89
+ head_dim = query_dim // attn.heads
90
+ dtype = query.dtype
91
+
92
+ # Get key-value heads
93
+ kv_heads = inner_dim // head_dim
94
+
95
+ query = query.view(batch_size, -1, attn.heads, head_dim)
96
+ key = key.view(batch_size, -1, kv_heads, head_dim)
97
+ value = value.view(batch_size, -1, kv_heads, head_dim)
98
+
99
+ # Apply Query-Key Norm if needed
100
+ if attn.norm_q is not None:
101
+ query = attn.norm_q(query)
102
+ if attn.norm_k is not None:
103
+ key = attn.norm_k(key)
104
+
105
+ # Apply RoPE if needed
106
+ if image_rotary_emb is not None:
107
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
108
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
109
+
110
+ query, key = query.to(dtype), key.to(dtype)
111
+
112
+ # Apply proportional attention if true
113
+ if base_sequence_length is not None:
114
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
115
+ else:
116
+ softmax_scale = attn.scale
117
+
118
+ # perform grouped-query attention (GQA)
119
+ n_rep = attn.heads // kv_heads
120
+ if n_rep >= 1:
121
+ key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
122
+ value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
123
+
124
+ # scaled_dot_product_attention expects attention_mask shape to be
125
+ # (batch, heads, source_length, target_length)
126
+ if attention_mask is not None:
127
+ attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
128
+
129
+ query = query.transpose(1, 2)
130
+ key = key.transpose(1, 2)
131
+ value = value.transpose(1, 2)
132
+
133
+ hidden_states = F.scaled_dot_product_attention(
134
+ query, key, value, attn_mask=attention_mask, scale=softmax_scale
135
+ )
136
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
137
+ hidden_states = hidden_states.type_as(query)
138
+
139
+ # linear proj
140
+ hidden_states = attn.to_out[0](hidden_states)
141
+ hidden_states = attn.to_out[1](hidden_states)
142
+ return hidden_states
143
+
144
+
145
+ class Lumina2TransformerBlock(nn.Module):
146
+ def __init__(
147
+ self,
148
+ dim: int,
149
+ num_attention_heads: int,
150
+ num_kv_heads: int,
151
+ multiple_of: int,
152
+ ffn_dim_multiplier: float,
153
+ norm_eps: float,
154
+ modulation: bool = True,
155
+ ) -> None:
156
+ super().__init__()
157
+ self.head_dim = dim // num_attention_heads
158
+ self.dim = dim
159
+ self.modulation = modulation
160
+
161
+ self.attn = Attention(
162
+ query_dim=dim,
163
+ cross_attention_dim=None,
164
+ dim_head=dim // num_attention_heads,
165
+ qk_norm="rms_norm",
166
+ heads=num_attention_heads,
167
+ kv_heads=num_kv_heads,
168
+ eps=1e-5,
169
+ bias=False,
170
+ out_bias=False,
171
+ processor=Lumina2AttnProcessor2_0(),
172
+ )
173
+
174
+ self.feed_forward = LuminaFeedForward(
175
+ dim=dim,
176
+ inner_dim=4 * dim,
177
+ multiple_of=multiple_of,
178
+ ffn_dim_multiplier=ffn_dim_multiplier,
179
+ )
180
+
181
+ if modulation:
182
+ self.norm1 = LuminaRMSNormZero(
183
+ embedding_dim=dim,
184
+ norm_eps=norm_eps,
185
+ norm_elementwise_affine=True,
186
+ )
187
+ else:
188
+ self.norm1 = RMSNorm(dim, eps=norm_eps)
189
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
190
+
191
+ self.norm2 = RMSNorm(dim, eps=norm_eps)
192
+ self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
193
+
194
+ def forward(
195
+ self,
196
+ hidden_states: torch.Tensor,
197
+ attention_mask: torch.Tensor,
198
+ image_rotary_emb: torch.Tensor,
199
+ temb: Optional[torch.Tensor] = None,
200
+ ) -> torch.Tensor:
201
+ if self.modulation:
202
+ norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
203
+ attn_output = self.attn(
204
+ hidden_states=norm_hidden_states,
205
+ encoder_hidden_states=norm_hidden_states,
206
+ attention_mask=attention_mask,
207
+ image_rotary_emb=image_rotary_emb,
208
+ )
209
+ hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
210
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
211
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
212
+ else:
213
+ norm_hidden_states = self.norm1(hidden_states)
214
+ attn_output = self.attn(
215
+ hidden_states=norm_hidden_states,
216
+ encoder_hidden_states=norm_hidden_states,
217
+ attention_mask=attention_mask,
218
+ image_rotary_emb=image_rotary_emb,
219
+ )
220
+ hidden_states = hidden_states + self.norm2(attn_output)
221
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
222
+ hidden_states = hidden_states + self.ffn_norm2(mlp_output)
223
+
224
+ return hidden_states
225
+
226
+
227
+ class Lumina2RotaryPosEmbed(nn.Module):
228
+ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2):
229
+ super().__init__()
230
+ self.theta = theta
231
+ self.axes_dim = axes_dim
232
+ self.axes_lens = axes_lens
233
+ self.patch_size = patch_size
234
+
235
+ self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta)
236
+
237
+ def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
238
+ freqs_cis = []
239
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
240
+ for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
241
+ emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
242
+ freqs_cis.append(emb)
243
+ return freqs_cis
244
+
245
+ def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
246
+ device = ids.device
247
+ if ids.device.type == "mps":
248
+ ids = ids.to("cpu")
249
+
250
+ result = []
251
+ for i in range(len(self.axes_dim)):
252
+ freqs = self.freqs_cis[i].to(ids.device)
253
+ index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
254
+ result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
255
+ return torch.cat(result, dim=-1).to(device)
256
+
257
+ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
258
+ batch_size, channels, height, width = hidden_states.shape
259
+ p = self.patch_size
260
+ post_patch_height, post_patch_width = height // p, width // p
261
+ image_seq_len = post_patch_height * post_patch_width
262
+ device = hidden_states.device
263
+
264
+ encoder_seq_len = attention_mask.shape[1]
265
+ l_effective_cap_len = attention_mask.sum(dim=1).tolist()
266
+ seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
267
+ max_seq_len = max(seq_lengths)
268
+
269
+ # Create position IDs
270
+ position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
271
+
272
+ for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
273
+ # add caption position ids
274
+ position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
275
+ position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len
276
+
277
+ # add image position ids
278
+ row_ids = (
279
+ torch.arange(post_patch_height, dtype=torch.int32, device=device)
280
+ .view(-1, 1)
281
+ .repeat(1, post_patch_width)
282
+ .flatten()
283
+ )
284
+ col_ids = (
285
+ torch.arange(post_patch_width, dtype=torch.int32, device=device)
286
+ .view(1, -1)
287
+ .repeat(post_patch_height, 1)
288
+ .flatten()
289
+ )
290
+ position_ids[i, cap_seq_len:seq_len, 1] = row_ids
291
+ position_ids[i, cap_seq_len:seq_len, 2] = col_ids
292
+
293
+ # Get combined rotary embeddings
294
+ freqs_cis = self._get_freqs_cis(position_ids)
295
+
296
+ # create separate rotary embeddings for captions and images
297
+ cap_freqs_cis = torch.zeros(
298
+ batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
299
+ )
300
+ img_freqs_cis = torch.zeros(
301
+ batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
302
+ )
303
+
304
+ for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
305
+ cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
306
+ img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]
307
+
308
+ # image patch embeddings
309
+ hidden_states = (
310
+ hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
311
+ .permute(0, 2, 4, 3, 5, 1)
312
+ .flatten(3)
313
+ .flatten(1, 2)
314
+ )
315
+
316
+ return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
317
+
318
+
319
+ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
320
+ r"""
321
+ Lumina2NextDiT: Diffusion model with a Transformer backbone.
322
+
323
+ Parameters:
324
+ sample_size (`int`): The width of the latent images. This is fixed during training since
325
+ it is used to learn a number of position embeddings.
326
+ patch_size (`int`, *optional*, defaults to 2):
327
+ The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
328
+ in_channels (`int`, *optional*, defaults to 4):
329
+ The number of input channels for the model. Typically, this matches the number of channels in the input
330
+ images.
331
+ hidden_size (`int`, *optional*, defaults to 4096):
332
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
333
+ hidden representations.
334
+ num_layers (`int`, *optional*, defaults to 32):
335
+ The number of layers in the model. This defines the depth of the neural network.
336
+ num_attention_heads (`int`, *optional*, defaults to 32):
337
+ The number of attention heads in each attention layer. This parameter specifies how many separate attention
338
+ mechanisms are used.
339
+ num_kv_heads (`int`, *optional*, defaults to 8):
340
+ The number of key-value heads in the attention mechanism, if different from the number of attention heads.
341
+ If None, it defaults to num_attention_heads.
342
+ multiple_of (`int`, *optional*, defaults to 256):
343
+ A factor that the hidden size should be a multiple of. This can help optimize certain hardware
344
+ configurations.
345
+ ffn_dim_multiplier (`float`, *optional*):
346
+ A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
347
+ the model configuration.
348
+ norm_eps (`float`, *optional*, defaults to 1e-5):
349
+ A small value added to the denominator for numerical stability in normalization layers.
350
+ scaling_factor (`float`, *optional*, defaults to 1.0):
351
+ A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
352
+ overall scale of the model's operations.
353
+ """
354
+
355
+ _supports_gradient_checkpointing = True
356
+ _no_split_modules = ["Lumina2TransformerBlock"]
357
+ _skip_layerwise_casting_patterns = ["x_embedder", "norm"]
358
+
359
+ @register_to_config
360
+ def __init__(
361
+ self,
362
+ sample_size: int = 128,
363
+ patch_size: int = 2,
364
+ in_channels: int = 16,
365
+ out_channels: Optional[int] = None,
366
+ hidden_size: int = 2304,
367
+ num_layers: int = 26,
368
+ num_refiner_layers: int = 2,
369
+ num_attention_heads: int = 24,
370
+ num_kv_heads: int = 8,
371
+ multiple_of: int = 256,
372
+ ffn_dim_multiplier: Optional[float] = None,
373
+ norm_eps: float = 1e-5,
374
+ scaling_factor: float = 1.0,
375
+ axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
376
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
377
+ cap_feat_dim: int = 1024,
378
+ ) -> None:
379
+ super().__init__()
380
+ self.out_channels = out_channels or in_channels
381
+
382
+ # 1. Positional, patch & conditional embeddings
383
+ self.rope_embedder = Lumina2RotaryPosEmbed(
384
+ theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size
385
+ )
386
+
387
+ self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size)
388
+
389
+ self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
390
+ hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps
391
+ )
392
+
393
+ # 2. Noise and context refinement blocks
394
+ self.noise_refiner = nn.ModuleList(
395
+ [
396
+ Lumina2TransformerBlock(
397
+ hidden_size,
398
+ num_attention_heads,
399
+ num_kv_heads,
400
+ multiple_of,
401
+ ffn_dim_multiplier,
402
+ norm_eps,
403
+ modulation=True,
404
+ )
405
+ for _ in range(num_refiner_layers)
406
+ ]
407
+ )
408
+
409
+ self.context_refiner = nn.ModuleList(
410
+ [
411
+ Lumina2TransformerBlock(
412
+ hidden_size,
413
+ num_attention_heads,
414
+ num_kv_heads,
415
+ multiple_of,
416
+ ffn_dim_multiplier,
417
+ norm_eps,
418
+ modulation=False,
419
+ )
420
+ for _ in range(num_refiner_layers)
421
+ ]
422
+ )
423
+ self.ori_inp_dit = "none"
424
+ self.ori_inp_refiner = None
425
+
426
+ # 3. Transformer blocks
427
+ self.layers = nn.ModuleList(
428
+ [
429
+ Lumina2TransformerBlock(
430
+ hidden_size,
431
+ num_attention_heads,
432
+ num_kv_heads,
433
+ multiple_of,
434
+ ffn_dim_multiplier,
435
+ norm_eps,
436
+ modulation=True,
437
+ )
438
+ for _ in range(num_layers)
439
+ ]
440
+ )
441
+
442
+ # 4. Output norm & projection
443
+ self.norm_out = LuminaLayerNormContinuous(
444
+ embedding_dim=hidden_size,
445
+ conditioning_embedding_dim=min(hidden_size, 1024),
446
+ elementwise_affine=False,
447
+ eps=1e-6,
448
+ bias=True,
449
+ out_dim=patch_size * patch_size * self.out_channels,
450
+ )
451
+
452
+ self.gradient_checkpointing = False
453
+
454
+ def forward(
455
+ self,
456
+ hidden_states: torch.Tensor,
457
+ timestep: torch.Tensor,
458
+ encoder_hidden_states: torch.Tensor,
459
+ encoder_attention_mask: torch.Tensor,
460
+ attention_kwargs: Optional[Dict[str, Any]] = None,
461
+ return_dict: bool = True,
462
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
463
+ if attention_kwargs is not None:
464
+ attention_kwargs = attention_kwargs.copy()
465
+ lora_scale = attention_kwargs.pop("scale", 1.0)
466
+ else:
467
+ lora_scale = 1.0
468
+
469
+ if USE_PEFT_BACKEND:
470
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
471
+ scale_lora_layers(self, lora_scale)
472
+ else:
473
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
474
+ logger.warning(
475
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
476
+ )
477
+
478
+ # 1. Condition, positional & patch embedding
479
+ batch_size, _, height, width = hidden_states.shape
480
+
481
+ temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
482
+
483
+ (
484
+ hidden_states,
485
+ context_rotary_emb,
486
+ noise_rotary_emb,
487
+ rotary_emb,
488
+ encoder_seq_lengths,
489
+ seq_lengths,
490
+ ) = self.rope_embedder(hidden_states, encoder_attention_mask)
491
+
492
+ hidden_states = self.x_embedder(hidden_states)
493
+
494
+ # 2. Context & noise refinement
495
+ for layer in self.context_refiner:
496
+ encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)
497
+
498
+ for layer in self.noise_refiner:
499
+ hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)
500
+
501
+ if self.ori_inp_dit!="none" and self.ori_inp_refiner is not None:
502
+ single_img_length = hidden_states.shape[1]//2
503
+ initial_part = hidden_states[:, :single_img_length]
504
+ refined_part = self.ori_inp_refiner(hidden_states[:, single_img_length:])
505
+ updated_hidden_states = torch.cat((initial_part, refined_part), dim=1)
506
+ hidden_states = updated_hidden_states
507
+
508
+ # 3. Joint Transformer blocks
509
+ max_seq_len = max(seq_lengths)
510
+ use_mask = len(set(seq_lengths)) > 1
511
+
512
+ attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
513
+ joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
514
+ for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
515
+ attention_mask[i, :seq_len] = True
516
+ joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
517
+ joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]
518
+
519
+ hidden_states = joint_hidden_states
520
+
521
+ for layer in self.layers:
522
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
523
+ hidden_states = self._gradient_checkpointing_func(
524
+ layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb
525
+ )
526
+ else:
527
+ hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb)
528
+
529
+ # 4. Output norm & projection
530
+ hidden_states = self.norm_out(hidden_states, temb)
531
+
532
+ # 5. Unpatchify
533
+ p = self.config.patch_size
534
+ output = []
535
+ for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
536
+ output.append(
537
+ hidden_states[i][encoder_seq_len:seq_len]
538
+ .view(height // p, width // p, p, p, self.out_channels)
539
+ .permute(4, 0, 2, 1, 3)
540
+ .flatten(3, 4)
541
+ .flatten(1, 2)
542
+ )
543
+ output = torch.stack(output, dim=0)
544
+
545
+ if USE_PEFT_BACKEND:
546
+ # remove `lora_scale` from each PEFT layer
547
+ unscale_lora_layers(self, lora_scale)
548
+
549
+ if not return_dict:
550
+ return (output,)
551
+ return Transformer2DModelOutput(sample=output)
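A hypothetical smoke test for the `Lumina2Transformer2DModel` defined above, assuming the wildcard imports from `diffusers.models.transformers.transformer_lumina2` resolve in your diffusers version; the tiny hyperparameters are illustrative and are not the values STAR ships with. Note that the head dimension (`hidden_size / num_attention_heads`) must equal `sum(axes_dim_rope)`, which is 96 with the defaults.
import torch

# Toy configuration: head_dim = 384 / 4 = 96 matches the default axes_dim_rope (32, 32, 32).
model = Lumina2Transformer2DModel(
    patch_size=2, in_channels=4, hidden_size=384, num_layers=2,
    num_refiner_layers=1, num_attention_heads=4, num_kv_heads=2, cap_feat_dim=64,
)
latents = torch.randn(1, 4, 32, 32)              # (B, C, H, W) latent batch
timestep = torch.tensor([0.5])                   # diffusion time in [0, 1]
cap_feats = torch.randn(1, 12, 64)               # (B, L_text, cap_feat_dim) caption features
cap_mask = torch.ones(1, 12, dtype=torch.bool)   # all caption tokens are valid

with torch.no_grad():
    sample = model(latents, timestep, cap_feats, cap_mask, return_dict=False)[0]
print(sample.shape)                              # (1, 4, 32, 32): same shape as the input latents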
star/models/pixel_encoder/vq_model.py ADDED
@@ -0,0 +1,510 @@
1
+
2
+ from dataclasses import dataclass, field
3
+ from typing import List
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from deepspeed.utils import logger
9
+
10
+ @dataclass
11
+ class ModelArgs:
12
+ codebook_size: int = 16384
13
+ codebook_embed_dim: int = 8
14
+ codebook_l2_norm: bool = False
15
+ codebook_show_usage: bool = True
16
+ commit_loss_beta: float = 0.25
17
+ entropy_loss_ratio: float = 0.0
18
+ encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
19
+ decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
20
+ z_channels: int = 256
21
+ dropout_p: float = 0.0
22
+ num_res_blocks: int = 2
23
+ ch: int=128
24
+ attn_num_heads: int = 1
25
+
26
+
27
+ class VQModel(nn.Module):
28
+ def __init__(self, config: ModelArgs):
29
+ super().__init__()
30
+ self.config = config
31
+ self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p, num_res_blocks=config.num_res_blocks, ch=config.ch, attn_num_heads=config.attn_num_heads)
32
+ self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p, num_res_blocks=config.num_res_blocks, ch=config.ch, attn_num_heads=config.attn_num_heads)
33
+
34
+ self.quantize = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
35
+ config.commit_loss_beta, config.entropy_loss_ratio,
36
+ config.codebook_l2_norm, config.codebook_show_usage)
37
+ self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
38
+ self.post_quant_conv = nn.Conv2d(config.codebook_embed_dim, config.z_channels, 1)
39
+
40
+
41
+ def encode(self, x):
42
+ h = self.encoder(x)
43
+ h = self.quant_conv(h)
44
+
45
+ quant, emb_loss, info = self.quantize(h)
46
+ return quant, emb_loss, info
47
+
48
+ def decode(self, quant):
49
+ quant = self.post_quant_conv(quant)
50
+ dec = self.decoder(quant)
51
+ return dec
52
+
53
+ def decode_code(self, code_b, shape=None, channel_first=True):
54
+ quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
55
+ dec = self.decode(quant_b)
56
+ return dec # [B, C, H, W]
57
+
58
+ def forward(self, input):
59
+ quant, diff, _ = self.encode(input)
60
+ dec = self.decode(quant)
61
+ return dec, diff
62
+
63
+ def get_codebook_entry(self, code_b, shape=None, channel_first=True):
64
+ quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
65
+ return quant_b
66
+
67
+ def image_to_seq(self, image):
68
+ quant, _, [_, _, indices] = self.encode(image)
69
+ batch_size = image.shape[0]
70
+ return indices.reshape(batch_size, -1)
71
+
72
+ def seq_to_image(self, tokens):
73
+ tokens = torch.clamp(tokens, min=0)
74
+ assert tokens.size(-1) == self.config.num_tokens, (
75
+ f"can not generate the image as the token length is {tokens.size(-1)} != {self.config.num_tokens}"
76
+ )
77
+ bs, HW = tokens.shape
78
+ H = W = int(math.sqrt(HW))
79
+ images = self.decode_code(tokens, shape=[bs, self.config.codebook_embed_dim, H, W])
80
+ images = torch.clip((images+1)/2, 0, 1)
81
+ images = torch.permute(images, [0, 2, 3, 1])
82
+
83
+ return images
84
+
85
+ def load_trained_weights(self, pretrained=None):
86
+ device_index = torch.cuda.current_device()
87
+ device = torch.device(f'cuda:{device_index}')
88
+ weights = torch.load(pretrained, map_location=device)
89
+ self.load_state_dict(weights, strict=True)
90
+
91
+
92
+ class Encoder(nn.Module):
93
+ def __init__(self, in_channels=3, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2,
94
+ norm_type='group', dropout=0.0, resamp_with_conv=True, z_channels=256, attn_num_heads=1):
95
+ super().__init__()
96
+ self.num_resolutions = len(ch_mult)
97
+ self.num_res_blocks = num_res_blocks
98
+ self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
99
+
100
+ # downsampling
101
+ in_ch_mult = (1,) + tuple(ch_mult)
102
+ self.conv_blocks = nn.ModuleList()
103
+ for i_level in range(self.num_resolutions):
104
+ conv_block = nn.Module()
105
+ # res & attn
106
+ res_block = nn.ModuleList()
107
+ attn_block = nn.ModuleList()
108
+ block_in = ch*in_ch_mult[i_level]
109
+ block_out = ch*ch_mult[i_level]
110
+ for _ in range(self.num_res_blocks):
111
+ res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
112
+ block_in = block_out
113
+ if i_level == self.num_resolutions - 1:
114
+ attn_block.append(AttnBlock(block_in, norm_type, attn_num_heads))
115
+ conv_block.res = res_block
116
+ conv_block.attn = attn_block
117
+ # downsample
118
+ if i_level != self.num_resolutions-1:
119
+ conv_block.downsample = Downsample(block_in, resamp_with_conv)
120
+ self.conv_blocks.append(conv_block)
121
+
122
+ # middle
123
+ self.mid = nn.ModuleList()
124
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
125
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type, num_heads=attn_num_heads))
126
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
127
+
128
+ # end
129
+ self.norm_out = Normalize(block_in, norm_type)
130
+ self.conv_out = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)
131
+
132
+
133
+ def forward(self, x):
134
+ h = self.conv_in(x)
135
+ # downsampling
136
+ for i_level, block in enumerate(self.conv_blocks):
137
+ for i_block in range(self.num_res_blocks):
138
+ h = block.res[i_block](h)
139
+ if len(block.attn) > 0:
140
+ h = block.attn[i_block](h)
141
+ if i_level != self.num_resolutions - 1:
142
+ h = block.downsample(h)
143
+
144
+ # middle
145
+ for mid_block in self.mid:
146
+ h = mid_block(h)
147
+
148
+ # end
149
+ h = self.norm_out(h)
150
+ h = nonlinearity(h)
151
+ h = self.conv_out(h)
152
+ return h
153
+
154
+
155
+
156
+ class Decoder(nn.Module):
157
+ def __init__(self, z_channels=256, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, norm_type="group",
158
+ dropout=0.0, resamp_with_conv=True, out_channels=3, attn_num_heads=1):
159
+ super().__init__()
160
+ self.num_resolutions = len(ch_mult)
161
+ self.num_res_blocks = num_res_blocks
162
+
163
+ block_in = ch*ch_mult[self.num_resolutions-1]
164
+ # z to block_in
165
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
166
+
167
+ # middle
168
+ self.mid = nn.ModuleList()
169
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
170
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type, num_heads=attn_num_heads))
171
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
172
+
173
+ # upsampling
174
+ self.conv_blocks = nn.ModuleList()
175
+ for i_level in reversed(range(self.num_resolutions)):
176
+ conv_block = nn.Module()
177
+ # res & attn
178
+ res_block = nn.ModuleList()
179
+ attn_block = nn.ModuleList()
180
+ block_out = ch*ch_mult[i_level]
181
+ for _ in range(self.num_res_blocks + 1):
182
+ res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
183
+ block_in = block_out
184
+ if i_level == self.num_resolutions - 1:
185
+ attn_block.append(AttnBlock(block_in, norm_type, attn_num_heads))
186
+ conv_block.res = res_block
187
+ conv_block.attn = attn_block
188
+ # upsample
189
+ if i_level != 0:
190
+ conv_block.upsample = Upsample(block_in, resamp_with_conv)
191
+ self.conv_blocks.append(conv_block)
192
+
193
+ # end
194
+ self.norm_out = Normalize(block_in, norm_type)
195
+ self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
196
+
197
+ @property
198
+ def last_layer(self):
199
+ return self.conv_out.weight
200
+
201
+ def forward(self, z):
202
+ # z to block_in
203
+ h = self.conv_in(z)
204
+
205
+ # middle
206
+ for mid_block in self.mid:
207
+ h = mid_block(h)
208
+
209
+ # upsampling
210
+ for i_level, block in enumerate(self.conv_blocks):
211
+ for i_block in range(self.num_res_blocks + 1):
212
+ h = block.res[i_block](h)
213
+ if len(block.attn) > 0:
214
+ h = block.attn[i_block](h)
215
+ if i_level != self.num_resolutions - 1:
216
+ h = block.upsample(h)
217
+
218
+ # end
219
+ h = self.norm_out(h)
220
+ h = nonlinearity(h)
221
+ h = self.conv_out(h)
222
+ return h
223
+
224
+
225
+ class VectorQuantizer(nn.Module):
226
+ def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage=False):
227
+ super().__init__()
228
+ self.n_e = n_e
229
+ self.e_dim = e_dim
230
+ self.beta = beta
231
+ self.entropy_loss_ratio = entropy_loss_ratio
232
+ self.l2_norm = l2_norm
233
+ self.show_usage = show_usage
234
+
235
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
236
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
237
+
238
+ if self.l2_norm:
239
+ self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
240
+ if self.show_usage:
241
+ if self.n_e < 65536:
242
+ self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
243
+ else:
244
+ self.register_buffer("codebook_used", nn.Parameter(torch.zeros(self.n_e+1)))
245
+ # self.register_buffer("codebook_used", nn.Parameter(torch.zeros(196608)))
246
+
247
+
248
+ # self.h_, self.w_ = int(self.n_e ** 0.5), int(self.n_e ** 0.5)
249
+ if int(self.n_e ** 0.5) ** 2 == self.n_e:
250
+ self.h_, self.w_ = int(self.n_e ** 0.5), int(self.n_e ** 0.5)
251
+ else:
252
+ self.h_ = int((self.n_e * 2) ** 0.5)
253
+ self.w_ = self.n_e // self.h_
254
+
255
+ def forward(self, z):
256
+ # reshape z -> (batch, height, width, channel) and flatten
257
+ z = torch.einsum('b c h w -> b h w c', z).contiguous()
258
+ z_flattened = z.view(z.shape[0], -1, self.e_dim) # [b, h*w, e_dim]
259
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
260
+
261
+ emb_weights = self.embedding.weight[None].repeat(z.shape[0], 1, 1)
262
+
263
+ if self.l2_norm:
264
+ z = F.normalize(z, p=2, dim=-1)
265
+ z_flattened = F.normalize(z_flattened, p=2, dim=-1)
266
+ embedding = F.normalize(emb_weights, p=2, dim=-1) # [b, n_e, e_dim]
267
+ else:
268
+ embedding = emb_weights
269
+
270
+ d = torch.sum(z_flattened ** 2, dim=2, keepdim=True) + \
271
+ torch.sum(embedding**2, dim=2).unsqueeze(1) - 2 * \
272
+ torch.einsum('bld,bnd->bln', z_flattened, embedding) # [n, h*w, n_e]
273
+
274
+ min_encoding_indices = torch.argmin(d, dim=2) # [n, h*w]
275
+ z_q = torch.stack([embedding[b, min_encoding_indices[b]] for b in range(z.shape[0])]) # [n, h*w, e_dim]
276
+ z_q = z_q.view(z.shape)
277
+ perplexity = None
278
+ min_encodings = None
279
+ vq_loss = None
280
+ commit_loss = None
281
+ entropy_loss = None
282
+ codebook_usage = 0
283
+
284
+ if self.show_usage and self.training:
285
+ self.codebook_used = self.codebook_used.long()
286
+ cur_len = min_encoding_indices.shape.numel()
287
+ self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()
288
+ self.codebook_used[-cur_len:] = min_encoding_indices.view(-1)
289
+ codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e
290
+
291
+
292
+ # compute loss for embedding
293
+ if self.training:
294
+ vq_loss = torch.mean((z_q - z.detach()) ** 2)
295
+ commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
296
+ entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d.view(-1, d.shape[-1]))
297
+
298
+ # preserve gradients
299
+ z_q = z + (z_q - z).detach()
300
+
301
+ # reshape back to match original input shape
302
+ z_q = torch.einsum('b h w c -> b c h w', z_q)
303
+
304
+ return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)
305
+
306
+ def get_codebook_entry(self, indices, shape=None, channel_first=True):
307
+
308
+ if self.l2_norm:
309
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1) # [n, n_e, e_dim]
310
+ else:
311
+ embedding = self.embedding.weight
312
+
313
+ z_q = embedding[indices]
314
+
315
+ if shape is not None:
316
+ if channel_first:
317
+ z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1]) # [B, H, W, D]
318
+ # reshape back to match original input shape
319
+ z_q = z_q.permute(0, 3, 1, 2).contiguous() # [B, D, H, W]
320
+ else:
321
+ z_q = z_q.view(shape)
322
+ return z_q
323
+
324
+
325
+ class ResnetBlock(nn.Module):
326
+ def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group'):
327
+ super().__init__()
328
+ self.in_channels = in_channels
329
+ out_channels = in_channels if out_channels is None else out_channels
330
+ self.out_channels = out_channels
331
+ self.use_conv_shortcut = conv_shortcut
332
+
333
+ self.norm1 = Normalize(in_channels, norm_type)
334
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
335
+ self.norm2 = Normalize(out_channels, norm_type)
336
+ self.dropout = nn.Dropout(dropout)
337
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
338
+
339
+ if self.in_channels != self.out_channels:
340
+ if self.use_conv_shortcut:
341
+ self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
342
+ else:
343
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
344
+
345
+ def forward(self, x):
346
+ h = x
347
+ h = self.norm1(h)
348
+ h = nonlinearity(h)
349
+ h = self.conv1(h)
350
+ h = self.norm2(h)
351
+ h = nonlinearity(h)
352
+ h = self.dropout(h)
353
+ h = self.conv2(h)
354
+ if self.in_channels != self.out_channels:
355
+ if self.use_conv_shortcut:
356
+ x = self.conv_shortcut(x)
357
+ else:
358
+ x = self.nin_shortcut(x)
359
+ return x+h
360
+
361
+
362
+
363
+ class AttnBlock(nn.Module):
364
+ def __init__(self, in_channels, norm_type='group', num_heads=1):
365
+ super().__init__()
366
+ self.num_heads = num_heads
367
+ assert in_channels % self.num_heads == 0
368
+
369
+ self.norm = Normalize(in_channels, norm_type)
370
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
371
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
372
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
373
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
374
+
375
+
376
+ def forward_single_head(self, x):
377
+ h_ = x
378
+ h_ = self.norm(h_)
379
+ q = self.q(h_)
380
+ k = self.k(h_)
381
+ v = self.v(h_)
382
+
383
+ # compute attention
384
+ b,c,h,w = q.shape
385
+ q = q.reshape(b,c,h*w)
386
+ q = q.permute(0,2,1) # b,hw,c
387
+ k = k.reshape(b,c,h*w) # b,c,hw
388
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
389
+ w_ = w_ * (int(c)**(-0.5))
390
+ w_ = F.softmax(w_, dim=2)
391
+
392
+ # attend to values
393
+ v = v.reshape(b,c,h*w)
394
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
395
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
396
+ h_ = h_.reshape(b,c,h,w)
397
+
398
+ h_ = self.proj_out(h_)
399
+
400
+ return x+h_
401
+
402
+ def forwar_multi_head(self, x):
403
+ h_ = x
404
+ h_ = self.norm(h_)
405
+ q = self.q(h_)
406
+ k = self.k(h_)
407
+ v = self.v(h_)
408
+
409
+ # compute attention
410
+ b, c, h, w = q.shape
411
+ q = q.reshape(b, self.num_heads, c//self.num_heads, h * w) # b, head, c, hw
412
+ q = q.permute(0, 1, 3, 2) # b, head, hw, c
413
+ k = k.reshape(b, self.num_heads, c//self.num_heads, h * w) # b, head, c, hw
414
+
415
+ # w_ = torch.bmm(q,k) # b,head,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
416
+ w_ = q @ k # b,head,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
417
+ w_ = w_ * (int(c // self.num_heads) ** (-0.5))
418
+ w_ = torch.nn.functional.softmax(w_, dim=3)
419
+
420
+ # attend to values
421
+ v = v.reshape(b, self.num_heads, c//self.num_heads, h * w) # b, head, c, hw
422
+
423
+ w_ = w_.permute(0, 1, 3, 2) # b,head,hw,hw (first hw of k, second of q)
424
+ h_ = v @ w_ # b, head,c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
425
+ h_ = h_.reshape(b, c, h, w)
426
+
427
+ h_ = self.proj_out(h_)
428
+
429
+ return x + h_
430
+
431
+ def forward(self, x):
432
+ if self.num_heads > 1:
433
+ return self.forward_multi_head(x)
434
+ else:
435
+ return self.forward_single_head(x)
436
+
437
+
438
+ def nonlinearity(x):
439
+ # swish
440
+ return x*torch.sigmoid(x)
441
+
442
+
443
+ def Normalize(in_channels, norm_type='group'):
444
+ assert norm_type in ['group', 'batch']
445
+ if norm_type == 'group':
446
+ return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
447
+ elif norm_type == 'batch':
448
+ return nn.SyncBatchNorm(in_channels)
449
+
450
+
451
+ class Upsample(nn.Module):
452
+ def __init__(self, in_channels, with_conv):
453
+ super().__init__()
454
+ self.with_conv = with_conv
455
+ if self.with_conv:
456
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
457
+
458
+ def forward(self, x):
459
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
460
+ if self.with_conv:
461
+ x = self.conv(x)
462
+ return x
463
+
464
+
465
+ class Downsample(nn.Module):
466
+ def __init__(self, in_channels, with_conv):
467
+ super().__init__()
468
+ self.with_conv = with_conv
469
+ if self.with_conv:
470
+ # no asymmetric padding in torch conv, must do it ourselves
471
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
472
+
473
+ def forward(self, x):
474
+ if self.with_conv:
475
+ pad = (0,1,0,1)
476
+ x = F.pad(x, pad, mode="constant", value=0)
477
+ x = self.conv(x)
478
+ else:
479
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
480
+ return x
481
+
482
+
483
+ def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
484
+ flat_affinity = affinity.reshape(-1, affinity.shape[-1])
485
+ flat_affinity /= temperature
486
+ probs = F.softmax(flat_affinity, dim=-1)
487
+ log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
488
+ if loss_type == "softmax":
489
+ target_probs = probs
490
+ else:
491
+ raise ValueError("Entropy loss {} not supported".format(loss_type))
492
+ avg_probs = torch.mean(target_probs, dim=0)
493
+ avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
494
+ sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
495
+ loss = sample_entropy - avg_entropy
496
+ return loss
497
+
498
+
499
+ #################################################################################
500
+ # VQ Model Configs #
501
+ #################################################################################
502
+
503
+
504
+ def VQ_Model(config, **kwargs):
505
+ model = VQModel(ModelArgs(encoder_ch_mult=[1, 2, 2, 4, 8], decoder_ch_mult=[1, 2, 2, 4, 8], codebook_size=config.image_token_size, codebook_embed_dim=config.n_embed, z_channels=512, ch=256, attn_num_heads=config.num_heads, **kwargs))
506
+
507
+ pretrained = config.model_path
508
+ if pretrained:
509
+ model.load_trained_weights(pretrained)
510
+ return model
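An illustrative round trip through the `VQModel` above with randomly initialised weights (so the reconstruction is meaningless); it assumes `deepspeed` is importable, since the module imports its logger, and the 256-pixel input and 16x16 token grid follow the default `ModelArgs`. It decodes via `decode_code` directly because `seq_to_image` expects a `num_tokens` attribute that the dataclass above does not define.
import torch

args = ModelArgs()                                 # defaults: 16384 codes, 8-dim codebook, 16x downsampling
vq = VQModel(args).eval()

with torch.no_grad():
    image = torch.rand(1, 3, 256, 256) * 2 - 1     # inputs roughly in [-1, 1]
    tokens = vq.image_to_seq(image)                # (1, 256): one code index per 16x16 latent position
    recon = vq.decode_code(tokens, shape=[1, args.codebook_embed_dim, 16, 16])
print(tokens.shape, recon.shape)                   # torch.Size([1, 256]) torch.Size([1, 3, 256, 256])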
star/models/rope_2d.py ADDED
@@ -0,0 +1,232 @@
1
+ import os
2
+ import copy
3
+ import json
4
+ import random
5
+ import logging
6
+ import re
7
+ import time
8
+ import math
9
+ import ast
10
+ from dataclasses import dataclass, field
11
+ from typing import Dict, Optional, Sequence, List, Tuple
12
+ from io import BytesIO
13
+ import base64
14
+
15
+ import numpy as np
16
+ import torch
17
+ from torch.utils.data import Dataset
18
+ from PIL import Image
19
+ from decord import VideoReader
20
+ import transformers
21
+
22
+
23
+ def get_rope_index_25(
24
+ spatial_merge_size: Optional[int] = 2,
25
+ input_ids: Optional[torch.LongTensor] = None,
26
+ image_grid_thw: Optional[torch.LongTensor] = None,
27
+ video_grid_thw: Optional[torch.LongTensor] = None,
28
+ second_per_grid_ts: Optional[torch.Tensor] = None,
29
+ attention_mask: Optional[torch.Tensor] = None,
30
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
31
+ """
32
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
33
+
34
+ Explanation:
35
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
36
+
37
+ For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
38
+ Examples:
39
+ input_ids: [T T T T T], here T is for text.
40
+ temporal position_ids: [0, 1, 2, 3, 4]
41
+ height position_ids: [0, 1, 2, 3, 4]
42
+ width position_ids: [0, 1, 2, 3, 4]
43
+
44
+ For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
45
+ and 1D rotary position embedding for text part.
46
+ Examples:
47
+ Temporal (Time): 3 patches, representing different segments of the video in time.
48
+ Height: 2 patches, dividing each frame vertically.
49
+ Width: 2 patches, dividing each frame horizontally.
50
+ We also have some important parameters:
51
+ fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
52
+ tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
53
+ temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
54
+ interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
55
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
56
+ vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
57
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
58
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
59
+ text temporal position_ids: [101, 102, 103, 104, 105]
60
+ text height position_ids: [101, 102, 103, 104, 105]
61
+ text width position_ids: [101, 102, 103, 104, 105]
62
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
63
+
64
+ Args:
65
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
66
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
67
+ it.
68
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
69
+ The temporal, height and width of feature shape of each image in LLM.
70
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
71
+ The temporal, height and width of feature shape of each video in LLM.
72
+ second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
73
+ The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
74
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
75
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
76
+
77
+ - 1 for tokens that are **not masked**,
78
+ - 0 for tokens that are **masked**.
79
+
80
+ Returns:
81
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
82
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
83
+ """
84
+ image_token_id = 151655
85
+ video_token_id = 151656
86
+ vision_start_token_id = 151652
87
+ mrope_position_deltas = []
88
+ if input_ids is not None and (
89
+ image_grid_thw is not None or video_grid_thw is not None
90
+ ):
91
+ total_input_ids = input_ids
92
+ if attention_mask is None:
93
+ attention_mask = torch.ones_like(total_input_ids)
94
+ position_ids = torch.ones(
95
+ 3,
96
+ input_ids.shape[0],
97
+ input_ids.shape[1],
98
+ dtype=input_ids.dtype,
99
+ device=input_ids.device,
100
+ )
101
+ image_index, video_index = 0, 0
102
+ attention_mask = attention_mask.to(total_input_ids.device)
103
+ for i, input_ids in enumerate(total_input_ids):
104
+ input_ids = input_ids[attention_mask[i] == 1]
105
+ image_nums, video_nums = 0, 0
106
+ vision_start_indices = torch.argwhere(
107
+ input_ids == vision_start_token_id
108
+ ).squeeze(1)
109
+ vision_tokens = input_ids[vision_start_indices + 1]
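# The token immediately after <|vision_start|> tells whether the block is an image or a video.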
110
+ image_nums = (vision_tokens == image_token_id).sum()
111
+ video_nums = (vision_tokens == video_token_id).sum()
112
+ input_tokens = input_ids.tolist()
113
+ llm_pos_ids_list: list = []
114
+ st = 0
115
+ remain_images, remain_videos = image_nums, video_nums
116
+ for _ in range(image_nums + video_nums):
117
+ if image_token_id in input_tokens and remain_images > 0:
118
+ ed_image = input_tokens.index(image_token_id, st)
119
+ else:
120
+ ed_image = len(input_tokens) + 1
121
+ if video_token_id in input_tokens and remain_videos > 0:
122
+ ed_video = input_tokens.index(video_token_id, st)
123
+ else:
124
+ ed_video = len(input_tokens) + 1
125
+ if ed_image < ed_video:
126
+ t, h, w = (
127
+ image_grid_thw[image_index][0],
128
+ image_grid_thw[image_index][1],
129
+ image_grid_thw[image_index][2],
130
+ )
131
+ second_per_grid_t = 0
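# Images span a single temporal grid, so their per-grid time interval is set to zero.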
132
+ image_index += 1
133
+ remain_images -= 1
134
+ ed = ed_image
135
+
136
+ else:
137
+ t, h, w = (
138
+ video_grid_thw[video_index][0],
139
+ video_grid_thw[video_index][1],
140
+ video_grid_thw[video_index][2],
141
+ )
142
+ if second_per_grid_ts is not None:
143
+ second_per_grid_t = second_per_grid_ts[video_index]
144
+ else:
145
+ second_per_grid_t = 1.0
146
+ video_index += 1
147
+ remain_videos -= 1
148
+ ed = ed_video
149
+ llm_grid_t, llm_grid_h, llm_grid_w = (
150
+ t.item(),
151
+ h.item() // spatial_merge_size,
152
+ w.item() // spatial_merge_size,
153
+ )
154
+ text_len = ed - st
155
+
156
+ st_idx = (
157
+ llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
158
+ )
159
+ llm_pos_ids_list.append(
160
+ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
161
+ )
162
+
163
+ range_tensor = torch.arange(llm_grid_t).view(-1, 1)
164
+ expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
165
+
166
+ time_tensor = expanded_range * second_per_grid_t * 2
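# The literal 2 appears to play the role of tokens_per_second (the upstream Qwen2.5-VL
# implementation reads this value from the vision config), so each temporal grid step
# advances the temporal position id by second_per_grid_t * 2 (truncated to an integer below).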
167
+
168
+ time_tensor_long = time_tensor.long()
169
+ t_index = time_tensor_long.flatten()
170
+
171
+ h_index = (
172
+ torch.arange(llm_grid_h)
173
+ .view(1, -1, 1)
174
+ .expand(llm_grid_t, -1, llm_grid_w)
175
+ .flatten()
176
+ )
177
+ w_index = (
178
+ torch.arange(llm_grid_w)
179
+ .view(1, 1, -1)
180
+ .expand(llm_grid_t, llm_grid_h, -1)
181
+ .flatten()
182
+ )
183
+ llm_pos_ids_list.append(
184
+ torch.stack([t_index, h_index, w_index]) + text_len + st_idx
185
+ )
186
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
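# Advance the cursor past the vision block just handled: ed points at its first
# placeholder token, and the grid contributes t * h * w placeholder tokens.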
187
+
188
+ if st < len(input_tokens):
189
+ st_idx = (
190
+ llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
191
+ )
192
+ text_len = len(input_tokens) - st
193
+ llm_pos_ids_list.append(
194
+ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
195
+ )
196
+
197
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
198
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
199
+ position_ids.device
200
+ )
201
+ mrope_position_deltas.append(
202
+ llm_positions.max() + 1 - len(total_input_ids[i])
203
+ )
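# The delta records how far the largest multimodal position id runs ahead of (or behind)
# the plain sequence length; it can be negative because a t x h x w vision grid spans
# fewer distinct position steps than it has tokens.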
204
+ mrope_position_deltas = torch.tensor(
205
+ mrope_position_deltas, device=input_ids.device
206
+ ).unsqueeze(1)
207
+ return position_ids, mrope_position_deltas
208
+ else:
209
+ if attention_mask is not None:
210
+ position_ids = attention_mask.long().cumsum(-1) - 1
211
+ position_ids.masked_fill_(attention_mask == 0, 1)
212
+ position_ids = (
213
+ position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
214
+ )
215
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(
216
+ -1, keepdim=True
217
+ )[0]
218
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
219
+ else:
220
+ position_ids = (
221
+ torch.arange(input_ids.shape[1], device=input_ids.device)
222
+ .view(1, 1, -1)
223
+ .expand(3, input_ids.shape[0], -1)
224
+ )
225
+ mrope_position_deltas = torch.zeros(
226
+ [input_ids.shape[0], 1],
227
+ device=input_ids.device,
228
+ dtype=input_ids.dtype,
229
+ )
230
+
231
+ return position_ids, mrope_position_deltas
232
+
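A minimal usage sketch under stated assumptions: the helper shown above is called here as `get_rope_index` with `spatial_merge_size = 2`, but its actual name, full signature, and how `spatial_merge_size` is supplied are not visible in this excerpt, and the text token ids (100, 101, 102) are arbitrary placeholders.

    import torch

    vision_start, image_pad = 151652, 151655       # assumed Qwen2-VL special token ids
    image_grid_thw = torch.tensor([[1, 4, 4]])     # one image: 1 x 4 x 4 ViT patches
    # With spatial_merge_size = 2 the image occupies (4 // 2) * (4 // 2) = 4 LLM tokens.
    input_ids = torch.tensor([[100, 101, vision_start] + [image_pad] * 4 + [102]])

    position_ids, deltas = get_rope_index(
        input_ids=input_ids,
        image_grid_thw=image_grid_thw,
        attention_mask=torch.ones_like(input_ids),
    )
    print(position_ids.shape)   # torch.Size([3, 1, 8])
    print(deltas)               # tensor([[-2]]): max position id 5, plus 1, minus length 8

In the upstream Qwen2-VL generation loop the returned deltas are cached and added to the cache position at each decoding step so that newly generated text tokens continue the 3D position sequence; this excerpt does not show how STAR consumes them.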