Upload files
Browse files- .gitattributes +2 -0
- README.md +3 -4
- app.py +763 -0
- assets/editing.png +3 -0
- assets/understand.png +3 -0
- requirements.txt +30 -0
- star/.DS_Store +0 -0
- star/configs/STAR_Qwen2.5-VL-3B.json +35 -0
- star/configs/STAR_Qwen2.5-VL-7B.json +35 -0
- star/models/adapter/projector.py +26 -0
- star/models/config.py +23 -0
- star/models/data_process_utils.py +65 -0
- star/models/model.py +587 -0
- star/models/pixel_decoder/lumina2_decoder.py +563 -0
- star/models/pixel_decoder/transformer_lumina2.py +770 -0
- star/models/pixel_decoder/transformer_lumina2_seq.py +551 -0
- star/models/pixel_encoder/vq_model.py +510 -0
- star/models/rope_2d.py +232 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/editing.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/understand.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,14 +1,13 @@
|
|
| 1 |
---
|
| 2 |
title: STAR
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.49.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
| 11 |
-
short_description: STAR Demo
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
title: STAR
|
| 3 |
+
emoji: 👁
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: yellow
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.49.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,763 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import spaces
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import random
|
| 8 |
+
import time
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from huggingface_hub import hf_hub_download
|
| 11 |
+
import subprocess
|
| 12 |
+
subprocess.run(
|
| 13 |
+
"pip install flash-attn==2.7.3 --no-build-isolation",
|
| 14 |
+
shell=True
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
from star.models.config import load_config_from_json, STARMultiModalConfig
|
| 18 |
+
from star.models.model import STARMultiModal
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
TEXTS = {
|
| 22 |
+
"zh": {
|
| 23 |
+
"title": "🌟 STAR 多模态演示",
|
| 24 |
+
"description": "基于STAR模型的多模态AI演示系统,支持文本生成图像、图像编辑和图像理解功能。",
|
| 25 |
+
"please_load_model": "请先加载模型!",
|
| 26 |
+
"please_upload_image": "请上传图像!",
|
| 27 |
+
"generation_failed": "生成失败!",
|
| 28 |
+
"generation_success_diffusion": "生成成功!",
|
| 29 |
+
"generation_success_vq": "生成成功!",
|
| 30 |
+
"edit_failed": "编辑失败!",
|
| 31 |
+
"edit_success_diffusion": "编辑成功!",
|
| 32 |
+
"edit_success_vq": "编辑成功!",
|
| 33 |
+
"understanding_failed": "理解失败!",
|
| 34 |
+
"generation_error": "生成过程中出错: ",
|
| 35 |
+
"edit_error": "编辑过程中出错: ",
|
| 36 |
+
"understanding_error": "理解过程中出错: ",
|
| 37 |
+
"tab_text_to_image": "🖼️ 文本生成图像",
|
| 38 |
+
"tab_image_edit": "🖌️ 图像编辑",
|
| 39 |
+
"tab_image_understanding": "📝 图像理解",
|
| 40 |
+
"text_prompt": "文本提示",
|
| 41 |
+
"text_prompt_placeholder": "A whimsical scene featuring a small elf with pointed ears and a green hat, sipping orange juice through a long straw from a disproportionately large orange. Next to the elf, a curious squirrel perches on its hind legs, while an owl with wide, observant eyes watches intently from a branch overhead. The orange's vibrant color contrasts with the muted browns and greens of the surrounding forest foliage.",
|
| 42 |
+
"advanced_params": "高级参数",
|
| 43 |
+
"cfg_scale": "CFG Scale",
|
| 44 |
+
"cfg_scale_info": "控制生成图像与文本的匹配程度",
|
| 45 |
+
"top_k": "Top-K",
|
| 46 |
+
"top_k_info": "采样时考虑的token数量",
|
| 47 |
+
"top_p": "Top-P",
|
| 48 |
+
"top_p_info": "核采样参数",
|
| 49 |
+
"generate_image": "🎨 生成图像",
|
| 50 |
+
"generated_image": "生成的图像",
|
| 51 |
+
"generation_status": "生成状态",
|
| 52 |
+
"input_image": "输入图像",
|
| 53 |
+
"edit_instruction": "编辑指令",
|
| 54 |
+
"edit_instruction_placeholder": "Remove the tiger in the water.",
|
| 55 |
+
"edit_image": "✏️ 编辑图像",
|
| 56 |
+
"edited_image": "编辑后的图像",
|
| 57 |
+
"edit_status": "编辑状态",
|
| 58 |
+
"question": "问题",
|
| 59 |
+
"question_placeholder": "Please describe the content of this image",
|
| 60 |
+
"max_generation_length": "最大生成长度",
|
| 61 |
+
"understand_image": "🔍 理解图像",
|
| 62 |
+
"understanding_result": "理解结果",
|
| 63 |
+
"usage_instructions": "使用说明",
|
| 64 |
+
"usage_step1": "1. **文本生成图像**: 输入文本描述,调整参数后点击生成",
|
| 65 |
+
"usage_step2": "2. **图像编辑**: 上传图像并输入编辑指令",
|
| 66 |
+
"usage_step3": "3. **图像理解**: 上传图像并提出问题",
|
| 67 |
+
"language": "语言 / Language"
|
| 68 |
+
},
|
| 69 |
+
"en": {
|
| 70 |
+
"title": "🌟 STAR Multi-Modal Demo",
|
| 71 |
+
"description": "A multi-modal AI demonstration system based on STAR model, supporting text-to-image generation, image editing, and image understanding.",
|
| 72 |
+
"please_load_model": "Please load the model first!",
|
| 73 |
+
"please_upload_image": "Please upload an image!",
|
| 74 |
+
"generation_failed": "Generation failed!",
|
| 75 |
+
"generation_success_diffusion": "Generation successful! ",
|
| 76 |
+
"generation_success_vq": "Generation successful! Using VQ decoder",
|
| 77 |
+
"edit_failed": "Editing failed!",
|
| 78 |
+
"edit_success_diffusion": "Editing successful! ",
|
| 79 |
+
"edit_success_vq": "Editing successful! Using VQ decoder",
|
| 80 |
+
"understanding_failed": "Understanding failed!",
|
| 81 |
+
"generation_error": "Error during generation: ",
|
| 82 |
+
"edit_error": "Error during editing: ",
|
| 83 |
+
"understanding_error": "Error during understanding: ",
|
| 84 |
+
"tab_text_to_image": "🖼️ Text to Image",
|
| 85 |
+
"tab_image_edit": "🖌️ Image Editing",
|
| 86 |
+
"tab_image_understanding": "📝 Image Understanding",
|
| 87 |
+
"text_prompt": "Text Prompt",
|
| 88 |
+
"text_prompt_placeholder": "A whimsical scene featuring a small elf with pointed ears and a green hat, sipping orange juice through a long straw from a disproportionately large orange. Next to the elf, a curious squirrel perches on its hind legs, while an owl with wide, observant eyes watches intently from a branch overhead. The orange's vibrant color contrasts with the muted browns and greens of the surrounding forest foliage.",
|
| 89 |
+
"advanced_params": "Advanced Parameters",
|
| 90 |
+
"cfg_scale": "CFG Scale",
|
| 91 |
+
"cfg_scale_info": "Controls how closely the generated image matches the text",
|
| 92 |
+
"top_k": "Top-K",
|
| 93 |
+
"top_k_info": "Number of tokens to consider during sampling",
|
| 94 |
+
"top_p": "Top-P",
|
| 95 |
+
"top_p_info": "Nucleus sampling parameter",
|
| 96 |
+
"generate_image": "🎨 Generate Image",
|
| 97 |
+
"generated_image": "Generated Image",
|
| 98 |
+
"generation_status": "Generation Status",
|
| 99 |
+
"input_image": "Input Image",
|
| 100 |
+
"edit_instruction": "Edit Instruction",
|
| 101 |
+
"edit_instruction_placeholder": "Remove the tiger in the water.",
|
| 102 |
+
"edit_image": "✏️ Edit Image",
|
| 103 |
+
"edited_image": "Edited Image",
|
| 104 |
+
"edit_status": "Edit Status",
|
| 105 |
+
"question": "Question",
|
| 106 |
+
"question_placeholder": "Please describe the content of this image",
|
| 107 |
+
"max_generation_length": "Max Generation Length",
|
| 108 |
+
"understand_image": "🔍 Understand Image",
|
| 109 |
+
"understanding_result": "Understanding Result",
|
| 110 |
+
"usage_instructions": "Usage Instructions",
|
| 111 |
+
"usage_step1": "1. **Text to Image**: Enter text description, adjust parameters and click generate",
|
| 112 |
+
"usage_step2": "2. **Image Editing**: Upload an image and enter editing instructions",
|
| 113 |
+
"usage_step3": "3. **Image Understanding**: Upload an image and ask questions",
|
| 114 |
+
"language": "语言 / Language"
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
class MockArgs:
|
| 119 |
+
def __init__(self):
|
| 120 |
+
self.data_type = "generation"
|
| 121 |
+
self.diffusion_as_decoder = True
|
| 122 |
+
self.ori_inp_dit = "seq"
|
| 123 |
+
self.grad_ckpt = False
|
| 124 |
+
self.diffusion_resolution = 1024
|
| 125 |
+
self.max_diff_seq_length = 256
|
| 126 |
+
self.max_seq_length = 8192
|
| 127 |
+
self.max_text_tokens = 512
|
| 128 |
+
self.max_pixels = 28 * 28 * 576
|
| 129 |
+
self.min_pixels = 28 * 28 * 16
|
| 130 |
+
self.vq_image_size = 384
|
| 131 |
+
self.vq_tokens = 576
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def set_seed(seed=100):
|
| 135 |
+
if seed > 0:
|
| 136 |
+
random.seed(seed)
|
| 137 |
+
np.random.seed(seed)
|
| 138 |
+
torch.manual_seed(seed)
|
| 139 |
+
if torch.cuda.is_available():
|
| 140 |
+
torch.cuda.manual_seed(seed)
|
| 141 |
+
torch.cuda.manual_seed_all(seed)
|
| 142 |
+
torch.backends.cudnn.deterministic = True
|
| 143 |
+
torch.backends.cudnn.benchmark = False
|
| 144 |
+
return seed
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def print_with_time(msg):
|
| 148 |
+
print(f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}: {msg}")
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class STARInferencer:
|
| 152 |
+
|
| 153 |
+
def __init__(self, model_config_path, checkpoint_path, vq_checkpoint, device="cpu"):
|
| 154 |
+
self.device = device
|
| 155 |
+
self.model_config_path = model_config_path
|
| 156 |
+
self.checkpoint_path = checkpoint_path
|
| 157 |
+
self.vq_checkpint_path = vq_checkpoint
|
| 158 |
+
self.model = None
|
| 159 |
+
self._load_model()
|
| 160 |
+
|
| 161 |
+
def _create_mock_args(self):
|
| 162 |
+
|
| 163 |
+
return MockArgs()
|
| 164 |
+
|
| 165 |
+
def _load_model(self):
|
| 166 |
+
try:
|
| 167 |
+
print_with_time("Loading model configuration...")
|
| 168 |
+
config_data = load_config_from_json(self.model_config_path)
|
| 169 |
+
model_config = STARMultiModalConfig(**config_data)
|
| 170 |
+
|
| 171 |
+
model_config.language_model.model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
|
| 172 |
+
model_config.pixel_encoder.model_path = self.vq_checkpint_path
|
| 173 |
+
model_config.pixel_decoder.model_path = "Alpha-VLLM/Lumina-Image-2.0"
|
| 174 |
+
|
| 175 |
+
args = self._create_mock_args()
|
| 176 |
+
|
| 177 |
+
print_with_time("Initializing model...")
|
| 178 |
+
self.model = STARMultiModal(model_config, args)
|
| 179 |
+
|
| 180 |
+
if os.path.exists(self.checkpoint_path):
|
| 181 |
+
print_with_time(f"Loading checkpoint from {self.checkpoint_path}")
|
| 182 |
+
with torch.no_grad():
|
| 183 |
+
checkpoint = torch.load(self.checkpoint_path, map_location='cpu', weights_only=False)
|
| 184 |
+
if 'state_dict' in checkpoint:
|
| 185 |
+
state_dict = checkpoint['state_dict']
|
| 186 |
+
else:
|
| 187 |
+
state_dict = checkpoint
|
| 188 |
+
|
| 189 |
+
if not isinstance(state_dict, dict):
|
| 190 |
+
raise ValueError("Invalid checkpoint format")
|
| 191 |
+
|
| 192 |
+
print_with_time(f"Checkpoint contains {len(state_dict)} parameters")
|
| 193 |
+
self.model.load_state_dict(state_dict, strict=False)
|
| 194 |
+
|
| 195 |
+
print_with_time(f"Moving model to device: {self.device}")
|
| 196 |
+
self.model.to(self.device)
|
| 197 |
+
|
| 198 |
+
print_with_time("Setting model to eval mode...")
|
| 199 |
+
self.model.eval()
|
| 200 |
+
|
| 201 |
+
if torch.cuda.is_available():
|
| 202 |
+
print_with_time(f"GPU memory after model loading: {torch.cuda.memory_allocated()/1024**3:.2f}GB")
|
| 203 |
+
|
| 204 |
+
print_with_time("Model loaded successfully!")
|
| 205 |
+
|
| 206 |
+
except Exception as e:
|
| 207 |
+
print_with_time(f"Error loading model: {str(e)}")
|
| 208 |
+
import traceback
|
| 209 |
+
traceback.print_exc()
|
| 210 |
+
raise e
|
| 211 |
+
|
| 212 |
+
@spaces.GPU(duration=210)
|
| 213 |
+
def generate_image(self, prompt, num_images=1, cfg=20.0, topk=2000, topp=1.0, seed=0):
|
| 214 |
+
|
| 215 |
+
if self.model.device.type == 'cpu':
|
| 216 |
+
print_with_time("Moving model to GPU...")
|
| 217 |
+
self.model.to('cuda')
|
| 218 |
+
self.model.to(torch.bfloat16)
|
| 219 |
+
print_with_time("Model moved to GPU")
|
| 220 |
+
|
| 221 |
+
set_seed(seed)
|
| 222 |
+
|
| 223 |
+
print_with_time(f"Generating image for prompt: {prompt}")
|
| 224 |
+
|
| 225 |
+
cfg = max(1.0, min(20.0, float(cfg)))
|
| 226 |
+
topk = max(100, min(2000, int(topk)))
|
| 227 |
+
topp = max(0.1, min(1.0, float(topp)))
|
| 228 |
+
|
| 229 |
+
print_with_time(f"Using validated params: cfg={cfg}, topk={topk}, topp={topp}")
|
| 230 |
+
|
| 231 |
+
if not (torch.isfinite(torch.tensor(cfg)) and torch.isfinite(torch.tensor(topk)) and torch.isfinite(torch.tensor(topp))):
|
| 232 |
+
print_with_time("Warning: Non-finite parameters detected")
|
| 233 |
+
return None
|
| 234 |
+
|
| 235 |
+
try:
|
| 236 |
+
with torch.no_grad():
|
| 237 |
+
if torch.cuda.is_available():
|
| 238 |
+
torch.cuda.empty_cache()
|
| 239 |
+
print_with_time(f"GPU memory before generation: {torch.cuda.memory_allocated()/1024**3:.2f}GB")
|
| 240 |
+
|
| 241 |
+
if not isinstance(prompt, str) or len(prompt.strip()) == 0:
|
| 242 |
+
print_with_time("Warning: Invalid prompt")
|
| 243 |
+
return None
|
| 244 |
+
|
| 245 |
+
if not (0 < cfg <= 20 and 0 < topk <= 5000 and 0 < topp <= 1):
|
| 246 |
+
print_with_time(f"Warning: Invalid parameters - cfg={cfg}, topk={topk}, topp={topp}")
|
| 247 |
+
return None
|
| 248 |
+
|
| 249 |
+
print_with_time("Calling model.generate_images...")
|
| 250 |
+
|
| 251 |
+
safe_max_tokens = 576
|
| 252 |
+
|
| 253 |
+
output = self.model.generate_images(
|
| 254 |
+
prompt,
|
| 255 |
+
max_new_tokens=safe_max_tokens,
|
| 256 |
+
num_return_sequences=num_images,
|
| 257 |
+
cfg_weight=cfg,
|
| 258 |
+
topk_sample=topk,
|
| 259 |
+
topp_sample=topp,
|
| 260 |
+
reasoning=False,
|
| 261 |
+
return_dict=True
|
| 262 |
+
)
|
| 263 |
+
print_with_time("Model generation completed")
|
| 264 |
+
|
| 265 |
+
if output is None:
|
| 266 |
+
print_with_time("Warning: Model returned None output")
|
| 267 |
+
return None
|
| 268 |
+
|
| 269 |
+
print_with_time("Processing output images...")
|
| 270 |
+
result = self._process_output_images(output, num_images)
|
| 271 |
+
print_with_time("Image processing completed")
|
| 272 |
+
return result
|
| 273 |
+
except Exception as e:
|
| 274 |
+
print_with_time(f"Error during image generation: {str(e)}")
|
| 275 |
+
import traceback
|
| 276 |
+
traceback.print_exc()
|
| 277 |
+
if torch.cuda.is_available():
|
| 278 |
+
torch.cuda.empty_cache()
|
| 279 |
+
raise e
|
| 280 |
+
|
| 281 |
+
@spaces.GPU(duration=210)
|
| 282 |
+
def edit_image(self, image, instruction, num_images=1, cfg=20.0, topk=2000, topp=1.0, seed=0):
|
| 283 |
+
|
| 284 |
+
if self.model.device.type == 'cpu':
|
| 285 |
+
print_with_time("Moving model to GPU...")
|
| 286 |
+
self.model.to('cuda')
|
| 287 |
+
self.model.to(torch.bfloat16)
|
| 288 |
+
print_with_time("Model moved to GPU")
|
| 289 |
+
|
| 290 |
+
set_seed(seed)
|
| 291 |
+
|
| 292 |
+
if isinstance(image, np.ndarray):
|
| 293 |
+
image = Image.fromarray(image)
|
| 294 |
+
|
| 295 |
+
print_with_time(f"Editing image with instruction: {instruction}")
|
| 296 |
+
|
| 297 |
+
with torch.no_grad():
|
| 298 |
+
output = self.model.generate_images_edit(
|
| 299 |
+
[image],
|
| 300 |
+
instruction,
|
| 301 |
+
max_new_tokens=576,
|
| 302 |
+
num_return_sequences=num_images,
|
| 303 |
+
cfg_weight=cfg,
|
| 304 |
+
topk_sample=topk,
|
| 305 |
+
topp_sample=topp,
|
| 306 |
+
return_dict=True
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
if output is None:
|
| 310 |
+
return None
|
| 311 |
+
|
| 312 |
+
return self._process_output_images(output, num_images)
|
| 313 |
+
|
| 314 |
+
@spaces.GPU(duration=180)
|
| 315 |
+
def understand_image(self, image, question, max_new_tokens=256):
|
| 316 |
+
|
| 317 |
+
if self.model.device.type == 'cpu':
|
| 318 |
+
print_with_time("Moving model to GPU...")
|
| 319 |
+
self.model.to('cuda')
|
| 320 |
+
self.model.to(torch.bfloat16)
|
| 321 |
+
print_with_time("Model moved to GPU")
|
| 322 |
+
|
| 323 |
+
if isinstance(image, np.ndarray):
|
| 324 |
+
image = Image.fromarray(image)
|
| 325 |
+
|
| 326 |
+
print_with_time(f"Understanding image with question: {question}")
|
| 327 |
+
|
| 328 |
+
with torch.no_grad():
|
| 329 |
+
answer = self.model.inference_understand(
|
| 330 |
+
image=image,
|
| 331 |
+
question=question,
|
| 332 |
+
max_new_tokens=max_new_tokens
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
return answer
|
| 336 |
+
|
| 337 |
+
def _process_output_images(self, output, num_images):
|
| 338 |
+
image_size = 384
|
| 339 |
+
|
| 340 |
+
try:
|
| 341 |
+
if isinstance(output, dict):
|
| 342 |
+
output_images = output.get("output_images")
|
| 343 |
+
diff_images = output.get("diff_images")
|
| 344 |
+
|
| 345 |
+
results = {}
|
| 346 |
+
|
| 347 |
+
if output_images is not None:
|
| 348 |
+
if isinstance(output_images, torch.Tensor):
|
| 349 |
+
output_images = output_images.detach().cpu().numpy()
|
| 350 |
+
|
| 351 |
+
if output_images.size == 0:
|
| 352 |
+
print_with_time("Warning: Empty output_images array")
|
| 353 |
+
results["vq_images"] = None
|
| 354 |
+
else:
|
| 355 |
+
output_images = np.nan_to_num(output_images, nan=0.0, posinf=1.0, neginf=-1.0)
|
| 356 |
+
dec_vq = np.clip((output_images + 1) / 2 * 255, 0, 255)
|
| 357 |
+
|
| 358 |
+
if len(dec_vq.shape) == 3:
|
| 359 |
+
dec_vq = dec_vq.reshape(num_images, image_size, image_size, 3)
|
| 360 |
+
|
| 361 |
+
visual_img_vq = np.zeros((num_images, image_size, image_size, 3), dtype=np.uint8)
|
| 362 |
+
visual_img_vq[:, :, :] = dec_vq
|
| 363 |
+
imgs_vq = [Image.fromarray(visual_img_vq[j].astype(np.uint8)) for j in range(visual_img_vq.shape[0])]
|
| 364 |
+
results["vq_images"] = imgs_vq
|
| 365 |
+
|
| 366 |
+
if diff_images is not None:
|
| 367 |
+
results["diff_images"] = diff_images
|
| 368 |
+
else:
|
| 369 |
+
results["diff_images"] = None
|
| 370 |
+
|
| 371 |
+
return results
|
| 372 |
+
else:
|
| 373 |
+
if isinstance(output, torch.Tensor):
|
| 374 |
+
output = output.detach().cpu().numpy()
|
| 375 |
+
|
| 376 |
+
output = np.nan_to_num(output, nan=0.0, posinf=1.0, neginf=-1.0)
|
| 377 |
+
dec = np.clip((output + 1) / 2 * 255, 0, 255)
|
| 378 |
+
|
| 379 |
+
if len(dec.shape) == 3:
|
| 380 |
+
dec = dec.reshape(num_images, image_size, image_size, 3)
|
| 381 |
+
|
| 382 |
+
visual_img = np.zeros((num_images, image_size, image_size, 3), dtype=np.uint8)
|
| 383 |
+
visual_img[:, :, :] = dec
|
| 384 |
+
imgs = [Image.fromarray(visual_img[j].astype(np.uint8)) for j in range(visual_img.shape[0])]
|
| 385 |
+
return {"vq_images": imgs, "diff_images": None}
|
| 386 |
+
|
| 387 |
+
except Exception as e:
|
| 388 |
+
print_with_time(f"Error in _process_output_images: {str(e)}")
|
| 389 |
+
return {"vq_images": None, "diff_images": None}
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
inferencer = None
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def save_language_setting(language):
|
| 397 |
+
try:
|
| 398 |
+
with open('.language_setting', 'w') as f:
|
| 399 |
+
f.write(language)
|
| 400 |
+
except:
|
| 401 |
+
pass
|
| 402 |
+
|
| 403 |
+
def update_interface_language(language):
|
| 404 |
+
global current_language
|
| 405 |
+
current_language = language
|
| 406 |
+
|
| 407 |
+
save_language_setting(language)
|
| 408 |
+
|
| 409 |
+
return [
|
| 410 |
+
language,
|
| 411 |
+
f"# {get_text('title')}",
|
| 412 |
+
get_text("description"),
|
| 413 |
+
get_text("text_prompt_placeholder"),
|
| 414 |
+
get_text("edit_instruction_placeholder"),
|
| 415 |
+
get_text("question_placeholder"),
|
| 416 |
+
f"""
|
| 417 |
+
---
|
| 418 |
+
### {get_text("usage_instructions")}
|
| 419 |
+
{get_text("usage_step1")}
|
| 420 |
+
{get_text("usage_step2")}
|
| 421 |
+
{get_text("usage_step3")}
|
| 422 |
+
""",
|
| 423 |
+
f"✅ Language switched to {language.upper()} successfully! / 语言已成功切换为{language.upper()}!" # 状态消息
|
| 424 |
+
]
|
| 425 |
+
|
| 426 |
+
current_language = "en"
|
| 427 |
+
|
| 428 |
+
def get_text(key):
|
| 429 |
+
return TEXTS[current_language].get(key, key)
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def auto_detect_device():
|
| 433 |
+
if torch.cuda.is_available():
|
| 434 |
+
device = f"cuda:{torch.cuda.current_device()}"
|
| 435 |
+
print_with_time(f"Detected CUDA device: {device}")
|
| 436 |
+
print_with_time(f"GPU name: {torch.cuda.get_device_name()}")
|
| 437 |
+
print_with_time(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
|
| 438 |
+
else:
|
| 439 |
+
device = "cpu"
|
| 440 |
+
print_with_time("No CUDA device detected, using CPU")
|
| 441 |
+
return device
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def initialize_model_on_startup():
|
| 445 |
+
global inferencer
|
| 446 |
+
|
| 447 |
+
default_checkpoint = hf_hub_download(
|
| 448 |
+
repo_id="MM-MVR/STAR-7B",
|
| 449 |
+
filename="STAR-7B.pt"
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
default_config = "star/configs/STAR_Qwen2.5-VL-7B.json"
|
| 453 |
+
|
| 454 |
+
vq_checkpoint = hf_hub_download(
|
| 455 |
+
repo_id="MM-MVR/STAR-VQ",
|
| 456 |
+
filename="VQ-Model.pt"
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
if not os.path.exists(default_config):
|
| 461 |
+
print_with_time(f"⚠️ Model config file not found: {default_config}")
|
| 462 |
+
return False, f"Model config file not found: {default_config}"
|
| 463 |
+
|
| 464 |
+
if not os.path.exists(default_checkpoint):
|
| 465 |
+
print_with_time(f"⚠️ Model checkpoint file not found: {default_checkpoint}")
|
| 466 |
+
return False, f"Model checkpoint file not found: {default_checkpoint}"
|
| 467 |
+
|
| 468 |
+
try:
|
| 469 |
+
device = 'cpu'
|
| 470 |
+
print_with_time("Starting to load STAR model...")
|
| 471 |
+
|
| 472 |
+
inferencer = STARInferencer(default_config, default_checkpoint, vq_checkpoint, device)
|
| 473 |
+
|
| 474 |
+
print_with_time("✅ STAR model loaded successfully!")
|
| 475 |
+
return True, "✅ STAR model loaded successfully!"
|
| 476 |
+
|
| 477 |
+
except Exception as e:
|
| 478 |
+
error_msg = f"❌ Model loading failed: {str(e)}"
|
| 479 |
+
print_with_time(error_msg)
|
| 480 |
+
return False, error_msg
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def text_to_image(prompt, cfg_scale=1.0, topk=1000, topp=0.8):
|
| 486 |
+
if inferencer is None:
|
| 487 |
+
return None, get_text("please_load_model")
|
| 488 |
+
|
| 489 |
+
cfg_scale = max(1.0, min(20.0, cfg_scale))
|
| 490 |
+
topk = max(100, min(2000, int(topk)))
|
| 491 |
+
topp = max(0.1, min(1.0, topp))
|
| 492 |
+
seed = 100
|
| 493 |
+
|
| 494 |
+
try:
|
| 495 |
+
print_with_time(f"Starting generation with params: cfg={cfg_scale}, topk={topk}, topp={topp}, seed={seed}")
|
| 496 |
+
result = inferencer.generate_image(prompt, cfg=cfg_scale, topk=topk, topp=topp, seed=seed)
|
| 497 |
+
|
| 498 |
+
if result is None:
|
| 499 |
+
return None, get_text("generation_failed")
|
| 500 |
+
|
| 501 |
+
if result.get("diff_images") and len(result["diff_images"]) > 0:
|
| 502 |
+
return result["diff_images"][0], get_text("generation_success_diffusion")
|
| 503 |
+
elif result.get("vq_images") and len(result["vq_images"]) > 0:
|
| 504 |
+
return result["vq_images"][0], get_text("generation_success_vq")
|
| 505 |
+
else:
|
| 506 |
+
return None, get_text("generation_failed")
|
| 507 |
+
|
| 508 |
+
except Exception as e:
|
| 509 |
+
return None, get_text("generation_error") + str(e)
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def image_editing(image, instruction, cfg_scale=1.0, topk=1000, topp=0.8):
|
| 513 |
+
if inferencer is None:
|
| 514 |
+
return None, get_text("please_load_model")
|
| 515 |
+
|
| 516 |
+
if image is None:
|
| 517 |
+
return None, get_text("please_upload_image")
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
cfg_scale = max(1.0, min(20.0, cfg_scale))
|
| 521 |
+
topk = max(100, min(2000, int(topk)))
|
| 522 |
+
topp = max(0.1, min(1.0, topp))
|
| 523 |
+
seed = 100
|
| 524 |
+
|
| 525 |
+
try:
|
| 526 |
+
print_with_time(f"Starting image editing with params: cfg={cfg_scale}, topk={topk}, topp={topp}, seed={seed}")
|
| 527 |
+
result = inferencer.edit_image(image, instruction, cfg=cfg_scale, topk=topk, topp=topp, seed=seed)
|
| 528 |
+
|
| 529 |
+
if result is None:
|
| 530 |
+
return None, get_text("edit_failed")
|
| 531 |
+
|
| 532 |
+
if result.get("diff_images") and len(result["diff_images"]) > 0:
|
| 533 |
+
return result["diff_images"][0], get_text("edit_success_diffusion")
|
| 534 |
+
elif result.get("vq_images") and len(result["vq_images"]) > 0:
|
| 535 |
+
return result["vq_images"][0], get_text("edit_success_vq")
|
| 536 |
+
else:
|
| 537 |
+
return None, get_text("edit_failed")
|
| 538 |
+
|
| 539 |
+
except Exception as e:
|
| 540 |
+
return None, get_text("edit_error") + str(e)
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def image_understanding(image, question, max_new_tokens=256):
|
| 544 |
+
if inferencer is None:
|
| 545 |
+
return get_text("please_load_model")
|
| 546 |
+
|
| 547 |
+
if image is None:
|
| 548 |
+
return get_text("please_upload_image")
|
| 549 |
+
|
| 550 |
+
try:
|
| 551 |
+
answer = inferencer.understand_image(image, question, max_new_tokens)
|
| 552 |
+
return answer if answer else get_text("understanding_failed")
|
| 553 |
+
|
| 554 |
+
except Exception as e:
|
| 555 |
+
return get_text("understanding_error") + str(e)
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def change_language(language):
|
| 559 |
+
global current_language
|
| 560 |
+
current_language = language
|
| 561 |
+
|
| 562 |
+
return (
|
| 563 |
+
get_text("title"),
|
| 564 |
+
get_text("description"),
|
| 565 |
+
get_text("tab_text_to_image"),
|
| 566 |
+
get_text("text_prompt"),
|
| 567 |
+
get_text("text_prompt_placeholder"),
|
| 568 |
+
get_text("advanced_params"),
|
| 569 |
+
get_text("cfg_scale"),
|
| 570 |
+
get_text("cfg_scale_info"),
|
| 571 |
+
get_text("top_k"),
|
| 572 |
+
get_text("top_k_info"),
|
| 573 |
+
get_text("top_p"),
|
| 574 |
+
get_text("top_p_info"),
|
| 575 |
+
get_text("random_seed"),
|
| 576 |
+
get_text("random_seed_info"),
|
| 577 |
+
get_text("generate_image"),
|
| 578 |
+
get_text("generated_image"),
|
| 579 |
+
get_text("generation_status"),
|
| 580 |
+
get_text("tab_image_edit"),
|
| 581 |
+
get_text("input_image"),
|
| 582 |
+
get_text("edit_instruction"),
|
| 583 |
+
get_text("edit_instruction_placeholder"),
|
| 584 |
+
get_text("edit_image"),
|
| 585 |
+
get_text("edited_image"),
|
| 586 |
+
get_text("edit_status"),
|
| 587 |
+
get_text("tab_image_understanding"),
|
| 588 |
+
get_text("question"),
|
| 589 |
+
get_text("question_placeholder"),
|
| 590 |
+
get_text("max_generation_length"),
|
| 591 |
+
get_text("understand_image"),
|
| 592 |
+
get_text("understanding_result"),
|
| 593 |
+
get_text("usage_instructions"),
|
| 594 |
+
get_text("usage_step1"),
|
| 595 |
+
get_text("usage_step2"),
|
| 596 |
+
get_text("usage_step3")
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
def load_example_image(image_path):
|
| 601 |
+
try:
|
| 602 |
+
if os.path.exists(image_path):
|
| 603 |
+
return Image.open(image_path)
|
| 604 |
+
except Exception as e:
|
| 605 |
+
print(f"Error loading example image: {e}")
|
| 606 |
+
return None
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
def create_interface():
|
| 611 |
+
|
| 612 |
+
print_with_time("Initializing STAR demo system...")
|
| 613 |
+
model_loaded, status_message = initialize_model_on_startup()
|
| 614 |
+
|
| 615 |
+
with gr.Blocks(title="🌟 STAR Multi-Modal Demo", theme=gr.themes.Soft()) as demo:
|
| 616 |
+
|
| 617 |
+
language_state = gr.State(value=current_language)
|
| 618 |
+
title_md = gr.Markdown(f"# {get_text('title')}")
|
| 619 |
+
desc_md = gr.Markdown(get_text("description"))
|
| 620 |
+
|
| 621 |
+
with gr.Row():
|
| 622 |
+
with gr.Column():
|
| 623 |
+
language_dropdown = gr.Dropdown(
|
| 624 |
+
choices=[("English", "en"), ("中文", "zh")],
|
| 625 |
+
value=current_language,
|
| 626 |
+
label="Language / 语言",
|
| 627 |
+
interactive=True
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
with gr.Tabs():
|
| 631 |
+
with gr.Tab(get_text("tab_text_to_image")) as txt_tab:
|
| 632 |
+
with gr.Row():
|
| 633 |
+
with gr.Column():
|
| 634 |
+
txt_prompt = gr.Textbox(
|
| 635 |
+
label=get_text("text_prompt"),
|
| 636 |
+
value=get_text("text_prompt_placeholder"),
|
| 637 |
+
lines=3
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
with gr.Accordion(get_text("advanced_params"), open=False):
|
| 641 |
+
txt_cfg_scale = gr.Slider(
|
| 642 |
+
minimum=1.0, maximum=20.0, value=1.1, step=0.1,
|
| 643 |
+
label=get_text("cfg_scale"), info=get_text("cfg_scale_info")
|
| 644 |
+
)
|
| 645 |
+
txt_topk = gr.Slider(
|
| 646 |
+
minimum=100, maximum=2000, value=1000, step=50,
|
| 647 |
+
label=get_text("top_k"), info=get_text("top_k_info")
|
| 648 |
+
)
|
| 649 |
+
txt_topp = gr.Slider(
|
| 650 |
+
minimum=0.1, maximum=1.0, value=0.8, step=0.05,
|
| 651 |
+
label=get_text("top_p"), info=get_text("top_p_info")
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
txt_generate_btn = gr.Button(get_text("generate_image"), variant="primary")
|
| 655 |
+
|
| 656 |
+
with gr.Column():
|
| 657 |
+
txt_output_image = gr.Image(label=get_text("generated_image"))
|
| 658 |
+
txt_status = gr.Textbox(label=get_text("generation_status"), interactive=False)
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
with gr.Tab(get_text("tab_image_edit")) as edit_tab:
|
| 662 |
+
with gr.Row():
|
| 663 |
+
with gr.Column():
|
| 664 |
+
edit_input_image = gr.Image(
|
| 665 |
+
label=get_text("input_image"),
|
| 666 |
+
value=load_example_image('assets/editing.png')
|
| 667 |
+
)
|
| 668 |
+
edit_instruction = gr.Textbox(
|
| 669 |
+
label=get_text("edit_instruction"),
|
| 670 |
+
value=get_text("edit_instruction_placeholder"),
|
| 671 |
+
lines=2
|
| 672 |
+
)
|
| 673 |
+
|
| 674 |
+
with gr.Accordion(get_text("advanced_params"), open=False):
|
| 675 |
+
edit_cfg_scale = gr.Slider(
|
| 676 |
+
minimum=1.0, maximum=20.0, value=1.1, step=0.1,
|
| 677 |
+
label=get_text("cfg_scale")
|
| 678 |
+
)
|
| 679 |
+
edit_topk = gr.Slider(
|
| 680 |
+
minimum=100, maximum=2000, value=1000, step=50,
|
| 681 |
+
label=get_text("top_k")
|
| 682 |
+
)
|
| 683 |
+
edit_topp = gr.Slider(
|
| 684 |
+
minimum=0.1, maximum=1.0, value=0.8, step=0.05,
|
| 685 |
+
label=get_text("top_p")
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
+
edit_btn = gr.Button(get_text("edit_image"), variant="primary")
|
| 689 |
+
|
| 690 |
+
with gr.Column():
|
| 691 |
+
edit_output_image = gr.Image(label=get_text("edited_image"))
|
| 692 |
+
edit_status = gr.Textbox(label=get_text("edit_status"), interactive=False)
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
with gr.Tab(get_text("tab_image_understanding")) as understand_tab:
|
| 696 |
+
with gr.Row():
|
| 697 |
+
with gr.Column():
|
| 698 |
+
understand_input_image = gr.Image(
|
| 699 |
+
label=get_text("input_image"),
|
| 700 |
+
value=load_example_image('assets/understand.png')
|
| 701 |
+
)
|
| 702 |
+
understand_question = gr.Textbox(
|
| 703 |
+
label=get_text("question"),
|
| 704 |
+
value=get_text("question_placeholder"),
|
| 705 |
+
lines=2
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
with gr.Accordion(get_text("advanced_params"), open=False):
|
| 709 |
+
understand_max_tokens = gr.Slider(
|
| 710 |
+
minimum=64, maximum=1024, value=256, step=64,
|
| 711 |
+
label=get_text("max_generation_length")
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
understand_btn = gr.Button(get_text("understand_image"), variant="primary")
|
| 715 |
+
|
| 716 |
+
with gr.Column():
|
| 717 |
+
understand_output = gr.Textbox(
|
| 718 |
+
label=get_text("understanding_result"),
|
| 719 |
+
lines=15,
|
| 720 |
+
interactive=False
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
+
usage_md = gr.Markdown(
|
| 724 |
+
f"""
|
| 725 |
+
---
|
| 726 |
+
### {get_text("usage_instructions")}
|
| 727 |
+
{get_text("usage_step1")}
|
| 728 |
+
{get_text("usage_step2")}
|
| 729 |
+
{get_text("usage_step3")}
|
| 730 |
+
"""
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
txt_generate_btn.click(
|
| 734 |
+
fn=text_to_image,
|
| 735 |
+
inputs=[txt_prompt, txt_cfg_scale, txt_topk, txt_topp],
|
| 736 |
+
outputs=[txt_output_image, txt_status]
|
| 737 |
+
)
|
| 738 |
+
|
| 739 |
+
edit_btn.click(
|
| 740 |
+
fn=image_editing,
|
| 741 |
+
inputs=[edit_input_image, edit_instruction, edit_cfg_scale, edit_topk, edit_topp],
|
| 742 |
+
outputs=[edit_output_image, edit_status]
|
| 743 |
+
)
|
| 744 |
+
|
| 745 |
+
understand_btn.click(
|
| 746 |
+
fn=image_understanding,
|
| 747 |
+
inputs=[understand_input_image, understand_question, understand_max_tokens],
|
| 748 |
+
outputs=understand_output
|
| 749 |
+
)
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
language_dropdown.change(
|
| 753 |
+
fn=update_interface_language,
|
| 754 |
+
inputs=[language_dropdown],
|
| 755 |
+
outputs=[language_state, title_md, desc_md, txt_prompt, edit_instruction, understand_question, usage_md, txt_status]
|
| 756 |
+
)
|
| 757 |
+
|
| 758 |
+
return demo
|
| 759 |
+
|
| 760 |
+
demo = create_interface()
|
| 761 |
+
|
| 762 |
+
demo.launch(share=True, show_error=True)
|
| 763 |
+
|
assets/editing.png
ADDED
|
Git LFS Details
|
assets/understand.png
ADDED
|
Git LFS Details
|
requirements.txt
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
torch>=2.2.2
|
| 3 |
+
torchvision>=0.17.2
|
| 4 |
+
torchaudio>=2.2.2
|
| 5 |
+
transformers==4.51.0
|
| 6 |
+
diffusers==0.33.0
|
| 7 |
+
decord>=0.6.0
|
| 8 |
+
attrdict2
|
| 9 |
+
accelerate>=0.32.0
|
| 10 |
+
timm>=1.0.15
|
| 11 |
+
opencv-python>=4.10.0
|
| 12 |
+
pillow>=10.4.0
|
| 13 |
+
einops>=0.8.0
|
| 14 |
+
xformers>=0.0.28
|
| 15 |
+
numpy>=1.26.0
|
| 16 |
+
pandas>=2.2.0
|
| 17 |
+
datasets>=3.0.0
|
| 18 |
+
tokenizers>=0.21.0
|
| 19 |
+
sentencepiece>=0.1.99
|
| 20 |
+
torchmetrics>=1.4.0
|
| 21 |
+
tqdm>=4.66.0
|
| 22 |
+
pyyaml>=6.0.0
|
| 23 |
+
requests>=2.32.0
|
| 24 |
+
packaging>=24.1
|
| 25 |
+
ipython>=8.26.0
|
| 26 |
+
matplotlib>=3.9.0
|
| 27 |
+
deepspeed>=0.14.4
|
| 28 |
+
wandb>=0.16.3
|
| 29 |
+
gradio>=5.34.0
|
| 30 |
+
qwen-vl-utils
|
star/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
star/configs/STAR_Qwen2.5-VL-3B.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_name": "STAR_Qwen2.5-3B_VQGAN",
|
| 3 |
+
"model_type": "STARMultiModalityConfig",
|
| 4 |
+
"language_model": {
|
| 5 |
+
"model_name": "Qwen2.5-VL",
|
| 6 |
+
"model_path": "checkpoints/Qwen2.5-VL-3B-Instruct"
|
| 7 |
+
},
|
| 8 |
+
"pixel_encoder": {
|
| 9 |
+
"model_name": "VQ_Model",
|
| 10 |
+
"model_path": "checkpoints/VQ-Model.pt",
|
| 11 |
+
"image_token_size": 65536,
|
| 12 |
+
"n_embed": 512,
|
| 13 |
+
"num_tokens": 576,
|
| 14 |
+
"num_heads": 8
|
| 15 |
+
},
|
| 16 |
+
"pixel_adapter": {
|
| 17 |
+
"model_name": "MLP_GELU",
|
| 18 |
+
"depth": 2,
|
| 19 |
+
"input_dim": 512,
|
| 20 |
+
"n_embed": 2048
|
| 21 |
+
},
|
| 22 |
+
"stacked_ar": {
|
| 23 |
+
"num_layers": 16
|
| 24 |
+
},
|
| 25 |
+
"pixel_output_head": {
|
| 26 |
+
"image_token_embed": 4096,
|
| 27 |
+
"image_token_size": 65536,
|
| 28 |
+
"n_embed": 2048
|
| 29 |
+
},
|
| 30 |
+
"pixel_decoder": {
|
| 31 |
+
"model_name": "LUMINA2",
|
| 32 |
+
"model_path": "checkpoints/lumina-image2"
|
| 33 |
+
},
|
| 34 |
+
"torch_dtype": "bfloat16"
|
| 35 |
+
}
|
star/configs/STAR_Qwen2.5-VL-7B.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_name": "STAR_Qwen2.5-7B_VQGAN",
|
| 3 |
+
"model_type": "STARMultiModalityConfig",
|
| 4 |
+
"language_model": {
|
| 5 |
+
"model_name": "Qwen2.5-VL",
|
| 6 |
+
"model_path": "checkpoints/Qwen2.5-VL-7B-Instruct"
|
| 7 |
+
},
|
| 8 |
+
"pixel_encoder": {
|
| 9 |
+
"model_name": "VQ_Model",
|
| 10 |
+
"model_path": "checkpoints/VQ-Model.pt",
|
| 11 |
+
"image_token_size": 65536,
|
| 12 |
+
"n_embed": 512,
|
| 13 |
+
"num_tokens": 576,
|
| 14 |
+
"num_heads": 8
|
| 15 |
+
},
|
| 16 |
+
"pixel_adapter": {
|
| 17 |
+
"model_name": "MLP_GELU",
|
| 18 |
+
"depth": 4,
|
| 19 |
+
"input_dim": 512,
|
| 20 |
+
"n_embed": 3584
|
| 21 |
+
},
|
| 22 |
+
"stacked_ar": {
|
| 23 |
+
"num_layers": 14
|
| 24 |
+
},
|
| 25 |
+
"pixel_output_head": {
|
| 26 |
+
"image_token_embed": 4096,
|
| 27 |
+
"image_token_size": 65536,
|
| 28 |
+
"n_embed": 3584
|
| 29 |
+
},
|
| 30 |
+
"pixel_decoder": {
|
| 31 |
+
"model_name": "LUMINA2",
|
| 32 |
+
"model_path": "checkpoints/lumina-image2"
|
| 33 |
+
},
|
| 34 |
+
"torch_dtype": "bfloat16"
|
| 35 |
+
}
|
star/models/adapter/projector.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
class MlpProjector(nn.Module):
|
| 6 |
+
def __init__(self, cfg):
|
| 7 |
+
super().__init__()
|
| 8 |
+
|
| 9 |
+
self.cfg = cfg
|
| 10 |
+
|
| 11 |
+
if cfg.model_name == "MLP_GELU":
|
| 12 |
+
mlp_depth = cfg.get("depth", 1)
|
| 13 |
+
modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
|
| 14 |
+
for _ in range(1, mlp_depth):
|
| 15 |
+
modules.append(nn.GELU())
|
| 16 |
+
modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
|
| 17 |
+
modules = nn.Sequential(*modules)
|
| 18 |
+
|
| 19 |
+
else:
|
| 20 |
+
raise ValueError(f"Unknown projector type: {cfg.model_name}")
|
| 21 |
+
|
| 22 |
+
self.layers = modules
|
| 23 |
+
|
| 24 |
+
def forward(self, x):
|
| 25 |
+
|
| 26 |
+
return self.layers(x)
|
star/models/config.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from attrdict2 import AttrDict
|
| 3 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def load_config_from_json(json_path):
|
| 7 |
+
with open(json_path, "r") as f:
|
| 8 |
+
config_data = json.load(f)
|
| 9 |
+
return config_data
|
| 10 |
+
|
| 11 |
+
class STARMultiModalConfig(PretrainedConfig):
|
| 12 |
+
model_type = "STARMultiModal"
|
| 13 |
+
|
| 14 |
+
def __init__(self, **kwargs):
|
| 15 |
+
super().__init__(**kwargs)
|
| 16 |
+
|
| 17 |
+
self.pixel_encoder = AttrDict(kwargs.get("pixel_encoder", {}))
|
| 18 |
+
self.pixel_adapter = AttrDict(kwargs.get("pixel_adapter", {}))
|
| 19 |
+
self.pixel_output_head = AttrDict(kwargs.get("pixel_output_head", {}))
|
| 20 |
+
self.language_model = AttrDict(kwargs.get("language_model", {}))
|
| 21 |
+
self.stacked_ar = AttrDict(kwargs.get("stacked_ar", {}))
|
| 22 |
+
self.pixel_decoder = AttrDict(kwargs.get("pixel_decoder", {}))
|
| 23 |
+
|
star/models/data_process_utils.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
|
| 7 |
+
BACKGROUND_COLOR=(127, 127, 127)
|
| 8 |
+
|
| 9 |
+
from torchvision.transforms import InterpolationMode
|
| 10 |
+
|
| 11 |
+
def preprocess_image_with_min_size(image, min_factor=28):
|
| 12 |
+
width, height = image.size
|
| 13 |
+
if height < min_factor or width < min_factor:
|
| 14 |
+
scale_factor = max(min_factor / height, min_factor / width)
|
| 15 |
+
new_width = int(width * scale_factor)
|
| 16 |
+
new_height = int(height * scale_factor)
|
| 17 |
+
|
| 18 |
+
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
| 19 |
+
return image
|
| 20 |
+
|
| 21 |
+
def preprocess_image_gen(images, processor, vq_transform):
|
| 22 |
+
|
| 23 |
+
image_list = []
|
| 24 |
+
grid_thw_list = []
|
| 25 |
+
vq_image_list = []
|
| 26 |
+
for image in images:
|
| 27 |
+
image = preprocess_image_with_min_size(image)
|
| 28 |
+
|
| 29 |
+
visual_processed = processor.preprocess(image, return_tensors="pt")
|
| 30 |
+
image_tensor = visual_processed["pixel_values"]
|
| 31 |
+
if isinstance(image_tensor, list):
|
| 32 |
+
image_tensor = image_tensor[0]
|
| 33 |
+
image_list.append(image_tensor)
|
| 34 |
+
|
| 35 |
+
grid_thw = visual_processed["image_grid_thw"][0]
|
| 36 |
+
grid_thw_list.append(grid_thw)
|
| 37 |
+
|
| 38 |
+
vq_image = vq_transform(image)
|
| 39 |
+
vq_image_list.append(vq_image)
|
| 40 |
+
|
| 41 |
+
image_tensor = torch.stack(image_list, dim=0)
|
| 42 |
+
grid_thw = torch.stack(grid_thw_list, dim=0)
|
| 43 |
+
vq_image = torch.stack(vq_image_list, dim=0)
|
| 44 |
+
|
| 45 |
+
return {
|
| 46 |
+
"pixel_values": image_tensor,
|
| 47 |
+
"image_grid_thw": grid_thw,
|
| 48 |
+
"vq_pixel_values": vq_image
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_vq_transform(args):
|
| 54 |
+
return transforms.Compose([
|
| 55 |
+
transforms.Resize((args.vq_image_size, args.vq_image_size), interpolation=InterpolationMode.BILINEAR),
|
| 56 |
+
transforms.ToTensor(), # [0, 255] -> [0, 1]
|
| 57 |
+
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), # [0, 1] -> [-1, 1]
|
| 58 |
+
])
|
| 59 |
+
|
| 60 |
+
def get_full_transform(args):
|
| 61 |
+
return transforms.Compose([
|
| 62 |
+
transforms.Resize((1024, 1024), interpolation=InterpolationMode.BILINEAR),
|
| 63 |
+
transforms.ToTensor(), # [0, 255] -> [0, 1]
|
| 64 |
+
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), # [0, 1] -> [-1, 1]
|
| 65 |
+
])
|
star/models/model.py
ADDED
|
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import requests
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchvision.transforms as T
|
| 10 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 11 |
+
from torch.nn import CrossEntropyLoss
|
| 12 |
+
from transformers import (
|
| 13 |
+
AutoConfig,
|
| 14 |
+
AutoTokenizer,
|
| 15 |
+
AutoModelForCausalLM,
|
| 16 |
+
PreTrainedModel
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, Qwen2VLProcessor
|
| 20 |
+
|
| 21 |
+
from star.models.config import STARMultiModalConfig
|
| 22 |
+
from star.models.pixel_encoder.vq_model import VQ_Model
|
| 23 |
+
from star.models.adapter.projector import MlpProjector
|
| 24 |
+
from star.models.pixel_decoder.lumina2_decoder import Lumina2Decoder
|
| 25 |
+
from star.models.data_process_utils import get_full_transform, get_vq_transform, preprocess_image_gen
|
| 26 |
+
from star.models.rope_2d import get_rope_index_25
|
| 27 |
+
|
| 28 |
+
class STARMultiModal(PreTrainedModel):
|
| 29 |
+
def __init__(self, config: STARMultiModalConfig, args=None, **kwargs):
|
| 30 |
+
super().__init__(config)
|
| 31 |
+
|
| 32 |
+
self.config = config
|
| 33 |
+
self.args = args if args is not None else kwargs.get("args", None)
|
| 34 |
+
|
| 35 |
+
# Pixel Encoder Generation
|
| 36 |
+
model_name = config.pixel_encoder.model_name
|
| 37 |
+
if model_name == "VQ_Model":
|
| 38 |
+
self.pixel_encoder = VQ_Model(config.pixel_encoder)
|
| 39 |
+
else:
|
| 40 |
+
assert None, f"Unsupported {model_name}"
|
| 41 |
+
self.pixel_encoder.eval()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Pixel Adapter Generation
|
| 45 |
+
model_name = config.pixel_adapter.model_name
|
| 46 |
+
if model_name == "MLP_GELU":
|
| 47 |
+
self.pixel_adapter = MlpProjector(config.pixel_adapter)
|
| 48 |
+
else:
|
| 49 |
+
assert None, f"Unsupported {model_name}"
|
| 50 |
+
|
| 51 |
+
# Pixel Ouput Head Generation
|
| 52 |
+
self.pixel_output_head = torch.nn.Linear(config.pixel_output_head.n_embed, config.pixel_output_head.image_token_size)
|
| 53 |
+
|
| 54 |
+
if getattr(args, "diffusion_as_decoder") and args.diffusion_as_decoder:
|
| 55 |
+
self.diffusion_decoder = Lumina2Decoder(config.pixel_decoder, args)
|
| 56 |
+
else:
|
| 57 |
+
self.diffusion_decoder = None
|
| 58 |
+
|
| 59 |
+
# Large Language Model
|
| 60 |
+
model_name, model_path = config.language_model.model_name, config.language_model.model_path
|
| 61 |
+
|
| 62 |
+
if model_name == "Qwen2.5-VL":
|
| 63 |
+
self.llm = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="cuda")
|
| 64 |
+
self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
|
| 65 |
+
self.tokenizer = self.processor.tokenizer
|
| 66 |
+
|
| 67 |
+
self.image_processor = self.processor.image_processor
|
| 68 |
+
self.image_processor.max_pixels = self.args.max_pixels
|
| 69 |
+
self.image_processor.min_pixels = self.args.min_pixels
|
| 70 |
+
self.image_processor.size["longest_edge"] = self.args.max_pixels
|
| 71 |
+
self.image_processor.size["shortest_edge"] = self.args.min_pixels
|
| 72 |
+
|
| 73 |
+
special_token_tags = ["<|vision_start|>", "<|vision_pad|>", "<|image_pad|>", "<|vision_end|>", "<|fim_pad|>"]
|
| 74 |
+
self.special_tokens = {tag: self.tokenizer.vocab.get(tag, None) for tag in special_token_tags}
|
| 75 |
+
|
| 76 |
+
else:
|
| 77 |
+
assert None, f"unsupported {model_name}: {model_path}"
|
| 78 |
+
self.llm.generation_config.pad_token_id = self.tokenizer.encode(self.tokenizer.pad_token)[0]
|
| 79 |
+
|
| 80 |
+
if self.args.grad_ckpt:
|
| 81 |
+
self.llm.gradient_checkpointing_enable()
|
| 82 |
+
self.llm.visual.gradient_checkpointing_enable()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
stacked_ar_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
| 86 |
+
num_layers_to_extract = config.stacked_ar.num_layers
|
| 87 |
+
stacked_ar_config.num_hidden_layers = num_layers_to_extract
|
| 88 |
+
|
| 89 |
+
self.stacked_ar = Qwen2_5_VLForConditionalGeneration(stacked_ar_config)
|
| 90 |
+
|
| 91 |
+
temp_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
|
| 92 |
+
total_layers = len(temp_model.model.layers)
|
| 93 |
+
start_layer = max(0, total_layers - num_layers_to_extract)
|
| 94 |
+
temp_model.model.layers = temp_model.model.layers[start_layer:]
|
| 95 |
+
self.stacked_ar.load_state_dict(temp_model.state_dict(), strict=False)
|
| 96 |
+
|
| 97 |
+
self.stacked_ar = self.stacked_ar.to("cuda")
|
| 98 |
+
del self.stacked_ar.visual, self.stacked_ar.model.embed_tokens, self.stacked_ar.lm_head
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# For Inference Generation
|
| 102 |
+
def generate_images(self, prompt, max_new_tokens=256, num_return_sequences=1, cfg_weight=5.0, topk_sample=1000, topp_sample=1.0, temperature=1.0, reasoning=False, return_dict=False):
|
| 103 |
+
|
| 104 |
+
if reasoning:
|
| 105 |
+
return self.generate_images_reasoning(prompt, max_new_tokens, num_return_sequences, cfg_weight, topk_sample, topp_sample, temperature, return_dict)
|
| 106 |
+
|
| 107 |
+
messages = [{'role': 'user', 'content': prompt}]
|
| 108 |
+
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 109 |
+
text_token = self.tokenizer.encode(text)
|
| 110 |
+
text_token = torch.tensor(text_token).long().to(self.device)
|
| 111 |
+
|
| 112 |
+
keys = list(self.special_tokens.keys())
|
| 113 |
+
start_token = (torch.ones(1) * self.special_tokens.get(keys[0])).long().to(self.device)
|
| 114 |
+
|
| 115 |
+
input_ids = torch.cat((text_token, start_token)).long().to(self.device)
|
| 116 |
+
tokens = torch.zeros((num_return_sequences*2, len(input_ids)), dtype=torch.int).cuda()
|
| 117 |
+
assistant_tokens = input_ids[-4:]
|
| 118 |
+
|
| 119 |
+
for i in range(num_return_sequences*2):
|
| 120 |
+
tokens[i, :] = input_ids
|
| 121 |
+
if i % 2 != 0:
|
| 122 |
+
tokens[i, 1:-1] = self.special_tokens.get(keys[4])
|
| 123 |
+
tokens[i, -4:] = assistant_tokens
|
| 124 |
+
|
| 125 |
+
inputs_embeds = self.llm.model.embed_tokens(tokens).to(self.device)
|
| 126 |
+
generated_tokens = torch.zeros((num_return_sequences, max_new_tokens), dtype=torch.int).cuda()
|
| 127 |
+
|
| 128 |
+
for i in range(max_new_tokens):
|
| 129 |
+
outputs = self.llm.model(
|
| 130 |
+
inputs_embeds=inputs_embeds,
|
| 131 |
+
use_cache=True,
|
| 132 |
+
past_key_values=outputs.past_key_values if i != 0 else None,
|
| 133 |
+
output_hidden_states=True)
|
| 134 |
+
last_hidden_states = outputs[0]
|
| 135 |
+
|
| 136 |
+
output_states = self.stacked_ar.model(
|
| 137 |
+
inputs_embeds=last_hidden_states,
|
| 138 |
+
past_key_values=output_states.past_key_values if i != 0 else None,
|
| 139 |
+
output_hidden_states=True,
|
| 140 |
+
use_cache=True)
|
| 141 |
+
|
| 142 |
+
last_hidden_states = output_states.hidden_states[-1]
|
| 143 |
+
|
| 144 |
+
logits = self.pixel_output_head(last_hidden_states[:, -1, :])
|
| 145 |
+
logit_cond = logits[0::2, :]
|
| 146 |
+
logit_uncond = logits[1::2, :]
|
| 147 |
+
logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
|
| 148 |
+
next_token, _ = self.sample(logits, temperature=1.0, top_k=topk_sample, top_p=topp_sample)
|
| 149 |
+
generated_tokens[:, i] = next_token.squeeze(dim=-1)
|
| 150 |
+
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
|
| 151 |
+
|
| 152 |
+
vqgan_embeds = self.pixel_encoder.get_codebook_entry(next_token)
|
| 153 |
+
img_embeds = self.pixel_adapter(vqgan_embeds)
|
| 154 |
+
inputs_embeds = img_embeds.unsqueeze(dim=1)
|
| 155 |
+
|
| 156 |
+
latent_size = int(math.sqrt(max_new_tokens))
|
| 157 |
+
output_images = self.pixel_encoder.decode_code(generated_tokens.to(dtype=torch.int), shape=[num_return_sequences, self.pixel_encoder.config.codebook_embed_dim, latent_size, latent_size])
|
| 158 |
+
output_images = output_images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
|
| 159 |
+
|
| 160 |
+
diff_images = None
|
| 161 |
+
if self.diffusion_decoder is not None:
|
| 162 |
+
gen_image_embeds = self.pixel_encoder.get_codebook_entry(generated_tokens)
|
| 163 |
+
|
| 164 |
+
if self.args.diffusion_resolution==512:
|
| 165 |
+
self.diffusion_decoder.pipe.transformer.config.sample_size=16
|
| 166 |
+
elif self.args.diffusion_resolution==1024:
|
| 167 |
+
self.diffusion_decoder.pipe.transformer.config.sample_size=32
|
| 168 |
+
diff_images = self.diffusion_decoder.pipe(
|
| 169 |
+
prompt,
|
| 170 |
+
num_inference_steps=40,
|
| 171 |
+
guidance_scale=4.5,
|
| 172 |
+
gen_image_embeds=gen_image_embeds, #gen_image_embeds,
|
| 173 |
+
control_emd="text",
|
| 174 |
+
ori_inp_way=self.diffusion_decoder.transformer.ori_inp_dit,
|
| 175 |
+
only_t2i="vqconcat",
|
| 176 |
+
img_guidance_scale=1.05,
|
| 177 |
+
height=self.args.diffusion_resolution,
|
| 178 |
+
width=self.args.diffusion_resolution
|
| 179 |
+
).images
|
| 180 |
+
if return_dict:
|
| 181 |
+
return {"output_images": output_images, "generated_tokens": generated_tokens, "diff_images": diff_images}
|
| 182 |
+
return output_images
|
| 183 |
+
|
| 184 |
+
def answer_text_qwen_vl(self, question, max_new_tokens=256, do_sample=True):
|
| 185 |
+
|
| 186 |
+
messages = [
|
| 187 |
+
{
|
| 188 |
+
"role": "user",
|
| 189 |
+
"content": [
|
| 190 |
+
{"type": "text", "text": question},
|
| 191 |
+
],
|
| 192 |
+
}
|
| 193 |
+
]
|
| 194 |
+
|
| 195 |
+
# Preparation for inference
|
| 196 |
+
text = self.processor.apply_chat_template(
|
| 197 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 198 |
+
)
|
| 199 |
+
# image_inputs, video_inputs = process_vision_info(messages)
|
| 200 |
+
inputs = self.processor(
|
| 201 |
+
text=[text],
|
| 202 |
+
images=None,
|
| 203 |
+
videos=None,
|
| 204 |
+
padding=True,
|
| 205 |
+
return_tensors="pt",
|
| 206 |
+
)
|
| 207 |
+
inputs = inputs.to(self.llm.device)
|
| 208 |
+
|
| 209 |
+
# Inference: Generation of the output
|
| 210 |
+
generated_ids = self.llm.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=do_sample)
|
| 211 |
+
generated_ids_trimmed = [
|
| 212 |
+
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 213 |
+
]
|
| 214 |
+
output_text = self.processor.batch_decode(
|
| 215 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
return output_text[0] if output_text else ""
|
| 219 |
+
|
| 220 |
+
def generate_images_reasoning(self, prompt, max_new_tokens=256, num_return_sequences=1, cfg_weight=5.0, topk_sample=1000, topp_sample=1.0, temperature=1.0, return_dict=False):
|
| 221 |
+
|
| 222 |
+
messages = [{'role': 'user', 'content': prompt}]
|
| 223 |
+
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 224 |
+
text_token = self.tokenizer.encode(text)
|
| 225 |
+
text_token = torch.tensor(text_token).long().to(self.device)
|
| 226 |
+
|
| 227 |
+
keys = list(self.special_tokens.keys())
|
| 228 |
+
start_token = (torch.ones(1) * self.special_tokens.get(keys[0])).long().to(self.device)
|
| 229 |
+
|
| 230 |
+
input_ids = torch.cat((text_token, start_token)).long().to(self.device)
|
| 231 |
+
tokens = torch.zeros((num_return_sequences*2, len(input_ids)), dtype=torch.int).cuda()
|
| 232 |
+
assistant_tokens = input_ids[-4:]
|
| 233 |
+
|
| 234 |
+
for i in range(num_return_sequences*2):
|
| 235 |
+
tokens[i, :] = input_ids
|
| 236 |
+
if i % 2 != 0:
|
| 237 |
+
tokens[i, 1:-1] = self.special_tokens.get(keys[4])
|
| 238 |
+
tokens[i, -4:] = assistant_tokens
|
| 239 |
+
|
| 240 |
+
generated_tokens = torch.zeros((num_return_sequences, max_new_tokens), dtype=torch.int).cuda()
|
| 241 |
+
answer_tokens_list = self.answer_text_qwen_vl(prompt, do_sample=False)
|
| 242 |
+
|
| 243 |
+
if answer_tokens_list:
|
| 244 |
+
answer_tokens_list = self.tokenizer.encode(answer_tokens_list, add_special_tokens=False)
|
| 245 |
+
answer_tokens = torch.tensor([answer_tokens_list], device=self.device) # [1, seq_len]
|
| 246 |
+
magic_prompt = " Ultra HD, 4K, cinematic composition"
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
magic_prompt_tokens = self.tokenizer.encode(magic_prompt, add_special_tokens=False)
|
| 250 |
+
magic_prompt_tensor = torch.tensor([magic_prompt_tokens], device=self.device) # [1, magic_seq_len]
|
| 251 |
+
|
| 252 |
+
answer_tokens = torch.cat([answer_tokens, magic_prompt_tensor], dim=1) # [1, seq_len + magic_seq_len]
|
| 253 |
+
answer_prompt = self.tokenizer.decode(answer_tokens[0]).split("assistant\n")[-1] #hjc see
|
| 254 |
+
|
| 255 |
+
special_token = self.special_tokens.get(keys[4])
|
| 256 |
+
special_token_tensor = torch.tensor([[special_token]], device=self.device)
|
| 257 |
+
special_token_expanded = special_token_tensor.expand(-1, answer_tokens.size(1))
|
| 258 |
+
|
| 259 |
+
answer_tokens_with_special = torch.cat([answer_tokens, special_token_expanded], dim=0)
|
| 260 |
+
|
| 261 |
+
batch_size = tokens.size(0) # num_return_sequences*2
|
| 262 |
+
answer_tokens_expanded = answer_tokens_with_special.repeat(batch_size // 2, 1)
|
| 263 |
+
|
| 264 |
+
input_tokens = torch.cat((tokens[:, :14], answer_tokens_expanded, tokens[:, -6:]), dim=1)
|
| 265 |
+
|
| 266 |
+
else:
|
| 267 |
+
input_tokens = tokens
|
| 268 |
+
answer_prompt = None
|
| 269 |
+
|
| 270 |
+
inputs_embeds = self.llm.model.embed_tokens(input_tokens).to(self.device)
|
| 271 |
+
|
| 272 |
+
for i in range(max_new_tokens):
|
| 273 |
+
outputs = self.llm.model(
|
| 274 |
+
inputs_embeds=inputs_embeds,
|
| 275 |
+
use_cache=True,
|
| 276 |
+
past_key_values=outputs.past_key_values if i != 0 else None,
|
| 277 |
+
output_hidden_states=True)
|
| 278 |
+
last_hidden_states = outputs[0]
|
| 279 |
+
|
| 280 |
+
output_states = self.stacked_ar.model(
|
| 281 |
+
inputs_embeds=last_hidden_states,
|
| 282 |
+
past_key_values=output_states.past_key_values if i != 0 else None,
|
| 283 |
+
output_hidden_states=True,
|
| 284 |
+
use_cache=True)
|
| 285 |
+
|
| 286 |
+
last_hidden_states = output_states.hidden_states[-1]
|
| 287 |
+
|
| 288 |
+
logits = self.pixel_output_head(last_hidden_states[:, -1, :])
|
| 289 |
+
logit_cond = logits[0::2, :]
|
| 290 |
+
logit_uncond = logits[1::2, :]
|
| 291 |
+
logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
|
| 292 |
+
next_token, _ = self.sample(logits, temperature=1.0, top_k=topk_sample, top_p=topp_sample)
|
| 293 |
+
generated_tokens[:, i] = next_token.squeeze(dim=-1)
|
| 294 |
+
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
|
| 295 |
+
|
| 296 |
+
vqgan_embeds = self.pixel_encoder.get_codebook_entry(next_token)
|
| 297 |
+
img_embeds = self.pixel_adapter(vqgan_embeds)
|
| 298 |
+
inputs_embeds = img_embeds.unsqueeze(dim=1)
|
| 299 |
+
|
| 300 |
+
latent_size = int(math.sqrt(max_new_tokens))
|
| 301 |
+
output_images = self.pixel_encoder.decode_code(generated_tokens.to(dtype=torch.int), shape=[num_return_sequences, self.pixel_encoder.config.codebook_embed_dim, latent_size, latent_size])
|
| 302 |
+
output_images = output_images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
|
| 303 |
+
|
| 304 |
+
diff_images = None
|
| 305 |
+
if self.diffusion_decoder is not None:
|
| 306 |
+
gen_image_embeds = self.pixel_encoder.get_codebook_entry(generated_tokens)
|
| 307 |
+
diff_prompt = answer_prompt if answer_prompt else prompt
|
| 308 |
+
if self.args.diffusion_resolution==512:
|
| 309 |
+
self.diffusion_decoder.pipe.transformer.config.sample_size=16
|
| 310 |
+
elif self.args.diffusion_resolution==1024:
|
| 311 |
+
self.diffusion_decoder.pipe.transformer.config.sample_size=32
|
| 312 |
+
diff_images = self.diffusion_decoder.pipe(
|
| 313 |
+
diff_prompt,
|
| 314 |
+
num_inference_steps=40,
|
| 315 |
+
guidance_scale=4.5,
|
| 316 |
+
gen_image_embeds=gen_image_embeds, #gen_image_embeds,
|
| 317 |
+
control_emd="text",
|
| 318 |
+
ori_inp_way=self.diffusion_decoder.transformer.ori_inp_dit,
|
| 319 |
+
only_t2i="vqconcat",
|
| 320 |
+
img_guidance_scale=1.05,
|
| 321 |
+
height=self.args.diffusion_resolution,
|
| 322 |
+
width=self.args.diffusion_resolution
|
| 323 |
+
).images
|
| 324 |
+
if return_dict:
|
| 325 |
+
return {"output_images":output_images,"generated_tokens":generated_tokens,"diff_images":diff_images,"answer_prompt":answer_prompt}
|
| 326 |
+
return output_images
|
| 327 |
+
|
| 328 |
+
def generate_images_edit(self, image, prompt, max_new_tokens=256, num_return_sequences=1, cfg_weight=5.0, topk_sample=1000, topp_sample=1.0, temperature=1.0,return_dict=False):
|
| 329 |
+
|
| 330 |
+
vq_image_transform = get_vq_transform(self.args)
|
| 331 |
+
full_image_transform = get_full_transform(self.args)
|
| 332 |
+
|
| 333 |
+
if isinstance(image, str):
|
| 334 |
+
image = Image.open(image).convert('RGB')
|
| 335 |
+
elif isinstance(image, list):
|
| 336 |
+
image = [each_image.convert('RGB') for each_image in image]
|
| 337 |
+
else:
|
| 338 |
+
image = image.convert('RGB')
|
| 339 |
+
|
| 340 |
+
messages = [{'role': 'user', 'content': prompt}]
|
| 341 |
+
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 342 |
+
text_token = self.tokenizer.encode(text)
|
| 343 |
+
text_token = torch.tensor(text_token).long().to(self.device)
|
| 344 |
+
|
| 345 |
+
keys = list(self.special_tokens.keys())
|
| 346 |
+
start_token = (torch.ones(1) * self.special_tokens.get(keys[0])).long().to(self.device)
|
| 347 |
+
user_prompt = "<|im_start|>user\n"
|
| 348 |
+
user_prompt_token = self.tokenizer.encode(user_prompt, add_special_tokens=False)
|
| 349 |
+
user_prompt_tensor = torch.tensor(user_prompt_token).long().to(self.device)
|
| 350 |
+
windows = text_token.unfold(0, len(user_prompt_tensor), 1)
|
| 351 |
+
matches = (windows == user_prompt_tensor).all(dim=1)
|
| 352 |
+
image_position = torch.where(matches)[0][0].item() + len(user_prompt_tensor)
|
| 353 |
+
|
| 354 |
+
input_ids = torch.cat((text_token, start_token)).long().to(self.device)
|
| 355 |
+
tokens = torch.zeros((num_return_sequences*2, len(input_ids)), dtype=torch.int).cuda()
|
| 356 |
+
assistant_tokens = input_ids[-4:]
|
| 357 |
+
|
| 358 |
+
for i in range(num_return_sequences*2):
|
| 359 |
+
tokens[i, :] = input_ids
|
| 360 |
+
if i % 2 != 0:
|
| 361 |
+
tokens[i, 1:-1] = self.special_tokens.get(keys[4])
|
| 362 |
+
tokens[i, -4:] = assistant_tokens
|
| 363 |
+
|
| 364 |
+
inputs_embeds = self.llm.model.embed_tokens(tokens).to(self.device)
|
| 365 |
+
position_ids = None
|
| 366 |
+
|
| 367 |
+
if image is not None:
|
| 368 |
+
image_info = preprocess_image_gen(image, self.image_processor, vq_image_transform)
|
| 369 |
+
image_embeds = self.llm.visual(image_info["pixel_values"].to(inputs_embeds.device,self.llm.visual.dtype), grid_thw=image_info["image_grid_thw"].to(inputs_embeds.device))
|
| 370 |
+
image_embeds = image_embeds[None,:].repeat(2, 1, 1).to(inputs_embeds.device, inputs_embeds.dtype)
|
| 371 |
+
|
| 372 |
+
vq_pixel_values = image_info["vq_pixel_values"].to(inputs_embeds.device)
|
| 373 |
+
B = inputs_embeds.size(0)
|
| 374 |
+
if len(vq_pixel_values.shape)==4:
|
| 375 |
+
vq_pixel_values = vq_pixel_values[:,None]
|
| 376 |
+
N = vq_pixel_values.size(1)
|
| 377 |
+
_, _, [_, _, vq_indices] = self.pixel_encoder.encode(vq_pixel_values.flatten(0, 1).bfloat16())
|
| 378 |
+
batch_size = vq_pixel_values.shape[0]
|
| 379 |
+
vq_indices = vq_indices.reshape(batch_size, N, vq_indices.shape[-1])
|
| 380 |
+
vqgan_dec_embeds = self.pixel_encoder.get_codebook_entry(vq_indices)
|
| 381 |
+
vq_embeds = self.pixel_adapter(vqgan_dec_embeds)
|
| 382 |
+
vq_embeds = vq_embeds.repeat(B, 1, 1, 1).to(inputs_embeds.device, inputs_embeds.dtype).flatten(1, 2)
|
| 383 |
+
|
| 384 |
+
vision_start_embeds = self.llm.model.embed_tokens(torch.tensor(self.tokenizer.encode("<|vision_start|>")).long().to(self.device))
|
| 385 |
+
vision_end_embeds = self.llm.model.embed_tokens(torch.tensor(self.tokenizer.encode("<|vision_end|>")).long().to(self.device))
|
| 386 |
+
newline_embeds = self.llm.model.embed_tokens(torch.tensor(self.tokenizer.encode("\n")).long().to(self.device))
|
| 387 |
+
vision_start_embeds = vision_start_embeds.unsqueeze(0).repeat(B, 1, 1)
|
| 388 |
+
vision_end_embeds = vision_end_embeds.unsqueeze(0).repeat(B, 1, 1)
|
| 389 |
+
newline_embeds = newline_embeds.unsqueeze(0).repeat(B, 1, 1)
|
| 390 |
+
|
| 391 |
+
inputs_embeds = torch.cat((inputs_embeds[:, :image_position],
|
| 392 |
+
vision_start_embeds, vq_embeds, vision_end_embeds,
|
| 393 |
+
vision_start_embeds, image_embeds, vision_end_embeds, newline_embeds,
|
| 394 |
+
inputs_embeds[:, image_position:]), dim=1)
|
| 395 |
+
|
| 396 |
+
SPECIAL_VQ_TOKEN = '<|vision_pad|>'
|
| 397 |
+
SPECIAL_VIT_TOKEN = '<|image_pad|>'
|
| 398 |
+
SPECIAL_VQ_TOKEN_ID = self.tokenizer.encode(SPECIAL_VQ_TOKEN)[0]
|
| 399 |
+
SPECIAL_VIT_TOKEN_ID = self.tokenizer.encode(SPECIAL_VIT_TOKEN)[0]
|
| 400 |
+
input_ids_for_position = torch.cat([input_ids[:image_position],
|
| 401 |
+
torch.tensor(self.tokenizer.encode("<|vision_start|>")).to(vq_embeds.device), torch.full((vq_embeds.shape[-2],), SPECIAL_VQ_TOKEN_ID, device=vq_embeds.device), torch.tensor(self.tokenizer.encode("<|vision_end|>")).to(vq_embeds.device),
|
| 402 |
+
torch.tensor(self.tokenizer.encode("<|vision_start|>")).to(vq_embeds.device), torch.full((image_embeds.shape[-2],), SPECIAL_VIT_TOKEN_ID, device=vq_embeds.device), torch.tensor(self.tokenizer.encode("<|vision_end|>")).to(vq_embeds.device), torch.tensor(self.tokenizer.encode("\n")).to(vq_embeds.device),
|
| 403 |
+
input_ids[image_position:],torch.full((vq_embeds.shape[-2],), SPECIAL_VQ_TOKEN_ID, device=vq_embeds.device)], dim=0)
|
| 404 |
+
position_ids, _ = get_rope_index_25(
|
| 405 |
+
self.image_processor.merge_size,
|
| 406 |
+
input_ids_for_position[None],
|
| 407 |
+
image_grid_thw=image_info["image_grid_thw"],
|
| 408 |
+
video_grid_thw=None,
|
| 409 |
+
second_per_grid_ts=None,
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
generated_tokens = torch.zeros((num_return_sequences, max_new_tokens), dtype=torch.int).cuda()
|
| 413 |
+
|
| 414 |
+
for i in range(max_new_tokens):
|
| 415 |
+
if i != 0:
|
| 416 |
+
real_position = position_ids[:,:,outputs.past_key_values.seen_tokens:(outputs.past_key_values.seen_tokens+inputs_embeds.shape[1])].to(inputs_embeds.device)
|
| 417 |
+
else:
|
| 418 |
+
real_position = position_ids[:,:,:inputs_embeds.shape[1]].to(inputs_embeds.device)
|
| 419 |
+
outputs = self.llm.model(
|
| 420 |
+
inputs_embeds=inputs_embeds,
|
| 421 |
+
use_cache=True,
|
| 422 |
+
position_ids = real_position,
|
| 423 |
+
past_key_values=outputs.past_key_values if i != 0 else None,
|
| 424 |
+
output_hidden_states=True)
|
| 425 |
+
last_hidden_states = outputs[0]
|
| 426 |
+
|
| 427 |
+
output_states = self.stacked_ar.model(
|
| 428 |
+
inputs_embeds=last_hidden_states,
|
| 429 |
+
past_key_values=output_states.past_key_values if i != 0 else None,
|
| 430 |
+
output_hidden_states=True,
|
| 431 |
+
position_ids = real_position,
|
| 432 |
+
use_cache=True)
|
| 433 |
+
|
| 434 |
+
last_hidden_states = output_states.hidden_states[-1]
|
| 435 |
+
|
| 436 |
+
logits = self.pixel_output_head(last_hidden_states[:, -1, :])
|
| 437 |
+
logit_cond = logits[0::2, :]
|
| 438 |
+
logit_uncond = logits[1::2, :]
|
| 439 |
+
logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
|
| 440 |
+
next_token, _ = self.sample(logits, temperature=1.0, top_k=topk_sample, top_p=topp_sample)
|
| 441 |
+
generated_tokens[:, i] = next_token.squeeze(dim=-1)
|
| 442 |
+
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
vqgan_embeds = self.pixel_encoder.get_codebook_entry(next_token)
|
| 446 |
+
img_embeds = self.pixel_adapter(vqgan_embeds)
|
| 447 |
+
inputs_embeds = img_embeds.unsqueeze(dim=1)
|
| 448 |
+
|
| 449 |
+
latent_size = int(math.sqrt(max_new_tokens))
|
| 450 |
+
output_images = self.pixel_encoder.decode_code(generated_tokens.to(dtype=torch.int), shape=[num_return_sequences, self.pixel_encoder.config.codebook_embed_dim, latent_size, latent_size])
|
| 451 |
+
output_images = output_images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
|
| 452 |
+
|
| 453 |
+
diff_images = None
|
| 454 |
+
if self.diffusion_decoder is not None:
|
| 455 |
+
|
| 456 |
+
gen_image_embeds = self.pixel_encoder.get_codebook_entry(generated_tokens)
|
| 457 |
+
|
| 458 |
+
if isinstance(image, list):
|
| 459 |
+
processed_img = [full_image_transform(each_image) for each_image in image]
|
| 460 |
+
else:
|
| 461 |
+
processed_img = [full_image_transform(image)]
|
| 462 |
+
if self.args.diffusion_resolution==512:
|
| 463 |
+
self.diffusion_decoder.pipe.transformer.config.sample_size=16
|
| 464 |
+
elif self.args.diffusion_resolution==1024:
|
| 465 |
+
self.diffusion_decoder.pipe.transformer.config.sample_size=32
|
| 466 |
+
diff_images = self.diffusion_decoder.pipe(
|
| 467 |
+
prompt,
|
| 468 |
+
num_inference_steps=50,
|
| 469 |
+
guidance_scale=3.0,
|
| 470 |
+
gen_image_embeds=gen_image_embeds, #gen_image_embeds,
|
| 471 |
+
control_emd="text",ori_inp_img=processed_img[0],ori_inp_way="seq",
|
| 472 |
+
only_t2i="vqconcat",img_guidance_scale=1.8,vq_guidance_scale=1,height=self.args.diffusion_resolution,width=self.args.diffusion_resolution
|
| 473 |
+
).images
|
| 474 |
+
if return_dict:
|
| 475 |
+
return {"output_images": output_images, "generated_tokens": None, "diff_images": diff_images}
|
| 476 |
+
return None
|
| 477 |
+
|
| 478 |
+
def sample(self, logits, temperature: float=1.0, top_k: int=0, top_p: float=1.0, sample_logits=True):
|
| 479 |
+
|
| 480 |
+
logits = logits / max(temperature, 1e-5)
|
| 481 |
+
if top_k > 0 or top_p < 1.0:
|
| 482 |
+
logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
| 483 |
+
probs = F.softmax(logits, dim=-1)
|
| 484 |
+
if sample_logits:
|
| 485 |
+
idx = torch.multinomial(probs, num_samples=1)
|
| 486 |
+
else:
|
| 487 |
+
_, idx = torch.topk(probs, k=1, dim=-1)
|
| 488 |
+
return idx, probs
|
| 489 |
+
|
| 490 |
+
def top_k_top_p_filtering(
|
| 491 |
+
self,
|
| 492 |
+
logits,
|
| 493 |
+
top_k: int = 0,
|
| 494 |
+
top_p: float = 1.0,
|
| 495 |
+
filter_value: float = -float("Inf"),
|
| 496 |
+
min_tokens_to_keep: int = 1,
|
| 497 |
+
):
|
| 498 |
+
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
| 499 |
+
"""
|
| 500 |
+
if top_k > 0:
|
| 501 |
+
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
|
| 502 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
| 503 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 504 |
+
logits[indices_to_remove] = filter_value
|
| 505 |
+
|
| 506 |
+
if top_p < 1.0:
|
| 507 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 508 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 509 |
+
|
| 510 |
+
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
|
| 511 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 512 |
+
if min_tokens_to_keep > 1:
|
| 513 |
+
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
| 514 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
| 515 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
| 516 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 517 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 518 |
+
|
| 519 |
+
# scatter sorted tensors to original indexing
|
| 520 |
+
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
| 521 |
+
logits[indices_to_remove] = filter_value
|
| 522 |
+
return logits
|
| 523 |
+
|
| 524 |
+
# For Inference Understand
|
| 525 |
+
def preprocess_image(self, image):
|
| 526 |
+
if image is None:
|
| 527 |
+
return None
|
| 528 |
+
if isinstance(image, str):
|
| 529 |
+
if os.path.exists(image):
|
| 530 |
+
pil_image = Image.open(image).convert('RGB')
|
| 531 |
+
else:
|
| 532 |
+
response = requests.get(image)
|
| 533 |
+
if response.status_code == 200:
|
| 534 |
+
image_bytes = BytesIO(response.content)
|
| 535 |
+
pil_image = Image.open(image_bytes).convert('RGB')
|
| 536 |
+
else:
|
| 537 |
+
raise ValueError(f"Failed to load image from url {image}")
|
| 538 |
+
elif isinstance(image, Image.Image):
|
| 539 |
+
pil_image = image.convert('RGB')
|
| 540 |
+
elif isinstance(image, list):
|
| 541 |
+
return self.preprocess_image(image[0])
|
| 542 |
+
else:
|
| 543 |
+
raise ValueError("Unsupported image type")
|
| 544 |
+
|
| 545 |
+
return pil_image
|
| 546 |
+
|
| 547 |
+
def inference_understand(self, image, question, max_new_tokens=256):
|
| 548 |
+
pil_image = self.preprocess_image(image)
|
| 549 |
+
|
| 550 |
+
messages = [
|
| 551 |
+
{
|
| 552 |
+
"role": "user",
|
| 553 |
+
"content": [
|
| 554 |
+
{
|
| 555 |
+
"type": "image",
|
| 556 |
+
"image": pil_image,
|
| 557 |
+
},
|
| 558 |
+
{"type": "text", "text": question},
|
| 559 |
+
],
|
| 560 |
+
}
|
| 561 |
+
]
|
| 562 |
+
|
| 563 |
+
from qwen_vl_utils import process_vision_info
|
| 564 |
+
# Preparation for inference
|
| 565 |
+
text = self.processor.apply_chat_template(
|
| 566 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 567 |
+
)
|
| 568 |
+
image_inputs, video_inputs = process_vision_info(messages)
|
| 569 |
+
inputs = self.processor(
|
| 570 |
+
text=[text],
|
| 571 |
+
images=image_inputs,
|
| 572 |
+
videos=video_inputs,
|
| 573 |
+
padding=True,
|
| 574 |
+
return_tensors="pt",
|
| 575 |
+
)
|
| 576 |
+
inputs = inputs.to(self.llm.device)
|
| 577 |
+
|
| 578 |
+
# Inference: Generation of the output
|
| 579 |
+
generated_ids = self.llm.generate(**inputs, max_new_tokens=max_new_tokens)
|
| 580 |
+
generated_ids_trimmed = [
|
| 581 |
+
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 582 |
+
]
|
| 583 |
+
output_text = self.processor.batch_decode(
|
| 584 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
return output_text[0] if output_text else ""
|
star/models/pixel_decoder/lumina2_decoder.py
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, Lumina2Pipeline
|
| 3 |
+
from transformers import AutoTokenizer, Gemma2Model
|
| 4 |
+
import copy
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from diffusers.training_utils import (
|
| 8 |
+
cast_training_params,
|
| 9 |
+
compute_density_for_timestep_sampling,
|
| 10 |
+
compute_loss_weighting_for_sd3,
|
| 11 |
+
free_memory,
|
| 12 |
+
)
|
| 13 |
+
from diffusers.pipelines.lumina2.pipeline_lumina2 import *
|
| 14 |
+
|
| 15 |
+
class Lumina2Decoder(torch.nn.Module):
|
| 16 |
+
def __init__(self, config, args):
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.diffusion_model_path = config.model_path
|
| 19 |
+
|
| 20 |
+
if not hasattr(args, "revision"):
|
| 21 |
+
args.revision = None
|
| 22 |
+
if not hasattr(args, "variant"):
|
| 23 |
+
args.variant = None
|
| 24 |
+
|
| 25 |
+
self.tokenizer_one = AutoTokenizer.from_pretrained(
|
| 26 |
+
self.diffusion_model_path,
|
| 27 |
+
subfolder="tokenizer",
|
| 28 |
+
revision=args.revision,
|
| 29 |
+
)
|
| 30 |
+
self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
| 31 |
+
self.diffusion_model_path, subfolder="scheduler"
|
| 32 |
+
)
|
| 33 |
+
self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
|
| 34 |
+
self.text_encoder_one = Gemma2Model.from_pretrained(
|
| 35 |
+
self.diffusion_model_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
| 36 |
+
)
|
| 37 |
+
self.text_encoding_pipeline = Lumina2Pipeline.from_pretrained(
|
| 38 |
+
self.diffusion_model_path,
|
| 39 |
+
vae=None,
|
| 40 |
+
transformer=None,
|
| 41 |
+
text_encoder=self.text_encoder_one,
|
| 42 |
+
tokenizer=self.tokenizer_one,
|
| 43 |
+
)
|
| 44 |
+
self.vae = AutoencoderKL.from_pretrained(
|
| 45 |
+
self.diffusion_model_path,
|
| 46 |
+
subfolder="vae",
|
| 47 |
+
revision=args.revision,
|
| 48 |
+
variant=args.variant,
|
| 49 |
+
)
|
| 50 |
+
if args.ori_inp_dit=="seq":
|
| 51 |
+
from star.models.pixel_decoder.transformer_lumina2_seq import Lumina2Transformer2DModel
|
| 52 |
+
elif args.ori_inp_dit=="ref":
|
| 53 |
+
from star.models.pixel_decoder.transformer_lumina2 import Lumina2Transformer2DModel
|
| 54 |
+
|
| 55 |
+
self.transformer = Lumina2Transformer2DModel.from_pretrained(
|
| 56 |
+
self.diffusion_model_path, subfolder="transformer", revision=args.revision, variant=args.variant
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
vq_dim = 512
|
| 60 |
+
patch_size = self.transformer.config.patch_size
|
| 61 |
+
in_channels = vq_dim + self.transformer.config.in_channels # 48 for mask
|
| 62 |
+
out_channels = self.transformer.x_embedder.out_features
|
| 63 |
+
|
| 64 |
+
load_num_channel = self.transformer.config.in_channels * patch_size * patch_size
|
| 65 |
+
self.transformer.register_to_config(in_channels=in_channels)
|
| 66 |
+
transformer = self.transformer
|
| 67 |
+
with torch.no_grad():
|
| 68 |
+
new_proj = nn.Linear(
|
| 69 |
+
in_channels * patch_size * patch_size, out_channels, bias=True
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
new_proj.weight.zero_()
|
| 73 |
+
|
| 74 |
+
new_proj = new_proj.to(transformer.x_embedder.weight.dtype)
|
| 75 |
+
new_proj.weight[:, :load_num_channel].copy_(transformer.x_embedder.weight)
|
| 76 |
+
new_proj.bias.copy_(transformer.x_embedder.bias)
|
| 77 |
+
transformer.x_embedder = new_proj
|
| 78 |
+
|
| 79 |
+
self.ori_inp_dit = args.ori_inp_dit
|
| 80 |
+
if args.ori_inp_dit=="seq":
|
| 81 |
+
refiner_channels = transformer.noise_refiner[-1].dim
|
| 82 |
+
with torch.no_grad():
|
| 83 |
+
vae2cond_proj1 = nn.Linear(refiner_channels, refiner_channels, bias=True)
|
| 84 |
+
vae2cond_act = nn.GELU(approximate='tanh')
|
| 85 |
+
vae2cond_proj2 = nn.Linear(refiner_channels, refiner_channels, bias=False)
|
| 86 |
+
vae2cond_proj2.weight.zero_()
|
| 87 |
+
|
| 88 |
+
ori_inp_refiner = nn.Sequential(
|
| 89 |
+
vae2cond_proj1,
|
| 90 |
+
vae2cond_act,
|
| 91 |
+
vae2cond_proj2
|
| 92 |
+
)
|
| 93 |
+
transformer.ori_inp_refiner = ori_inp_refiner
|
| 94 |
+
transformer.ori_inp_dit = self.ori_inp_dit
|
| 95 |
+
elif args.ori_inp_dit=="ref":
|
| 96 |
+
transformer.initialize_ref_weights()
|
| 97 |
+
transformer.ori_inp_dit = self.ori_inp_dit
|
| 98 |
+
|
| 99 |
+
transformer.requires_grad_(True)
|
| 100 |
+
|
| 101 |
+
if args.grad_ckpt and args.diffusion_resolution==1024:
|
| 102 |
+
transformer.gradient_checkpointing = args.grad_ckpt
|
| 103 |
+
transformer.enable_gradient_checkpointing()
|
| 104 |
+
|
| 105 |
+
self.vae.requires_grad_(False)
|
| 106 |
+
self.vae.to(dtype=torch.float32)
|
| 107 |
+
self.args = args
|
| 108 |
+
|
| 109 |
+
self.pipe = Lumina2InstructPix2PixPipeline.from_pretrained(self.diffusion_model_path,
|
| 110 |
+
transformer=transformer,
|
| 111 |
+
text_encoder=self.text_encoder_one,
|
| 112 |
+
vae=self.vae,
|
| 113 |
+
torch_dtype=torch.bfloat16)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
with torch.no_grad():
|
| 117 |
+
_, _, self.uncond_prompt_embeds, self.uncond_prompt_attention_mask = self.text_encoding_pipeline.encode_prompt(
|
| 118 |
+
"",
|
| 119 |
+
max_sequence_length=self.args.max_diff_seq_length,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
def compute_text_embeddings(self,prompt, text_encoding_pipeline):
|
| 123 |
+
with torch.no_grad():
|
| 124 |
+
prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(
|
| 125 |
+
prompt,
|
| 126 |
+
max_sequence_length=self.args.max_diff_seq_length,
|
| 127 |
+
)
|
| 128 |
+
return prompt_embeds, prompt_attention_mask
|
| 129 |
+
|
| 130 |
+
def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
|
| 131 |
+
sigmas = self.noise_scheduler_copy.sigmas.to(dtype=dtype)
|
| 132 |
+
schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device=timesteps.device)
|
| 133 |
+
timesteps = timesteps
|
| 134 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
| 135 |
+
|
| 136 |
+
sigma = sigmas[step_indices].flatten()
|
| 137 |
+
while len(sigma.shape) < n_dim:
|
| 138 |
+
sigma = sigma.unsqueeze(-1)
|
| 139 |
+
return sigma
|
| 140 |
+
|
| 141 |
+
def forward(self, batch_gpu,batch, image_embeds):
|
| 142 |
+
args = self.args
|
| 143 |
+
pixel_values = batch_gpu["full_pixel_values"].to(dtype=self.vae.dtype) #aux_image
|
| 144 |
+
data_type = "t2i"
|
| 145 |
+
if len(pixel_values.shape)==5:
|
| 146 |
+
bs,num_img,c,h,w = pixel_values.shape
|
| 147 |
+
if num_img==2:
|
| 148 |
+
data_type = "edit"
|
| 149 |
+
pixel_values_ori_img = pixel_values[:,0]
|
| 150 |
+
pixel_values = pixel_values[:,-1]
|
| 151 |
+
pixel_values = F.interpolate(pixel_values, size=(self.args.diffusion_resolution, self.args.diffusion_resolution), mode='bilinear',align_corners=False)
|
| 152 |
+
if data_type=="edit" and self.ori_inp_dit!="none":
|
| 153 |
+
pixel_values_ori_img = F.interpolate(pixel_values_ori_img, size=(self.args.diffusion_resolution, self.args.diffusion_resolution), mode='bilinear', align_corners=False)
|
| 154 |
+
prompt = batch["prompts"]
|
| 155 |
+
bs,_,_,_ = pixel_values.shape
|
| 156 |
+
image_prompt_embeds = None
|
| 157 |
+
image_embeds_2d = image_embeds.reshape(bs, 24, 24, image_embeds.shape[-1]).permute(0, 3, 1, 2)
|
| 158 |
+
image_embeds_2d = F.interpolate(image_embeds_2d, size=(args.diffusion_resolution//8, args.diffusion_resolution//8), mode='bilinear', align_corners=False)
|
| 159 |
+
|
| 160 |
+
control_emd = args.control_emd
|
| 161 |
+
prompt_embeds, prompt_attention_mask = self.compute_text_embeddings(prompt, self.text_encoding_pipeline)
|
| 162 |
+
if control_emd=="mix":
|
| 163 |
+
prompt_embeds=torch.cat([prompt_embeds, image_prompt_embeds], dim=1) #use mix
|
| 164 |
+
elif control_emd=="null":
|
| 165 |
+
prompt_embeds = torch.zeros_like(prompt_embeds)
|
| 166 |
+
prompt_attention_mask = torch.ones_like(prompt_attention_mask)
|
| 167 |
+
elif control_emd=="text":
|
| 168 |
+
pass
|
| 169 |
+
elif control_emd=="vit" or control_emd=="vq" or control_emd=="vqvae" or control_emd=="vqconcat" or control_emd=="vqconcatvit":
|
| 170 |
+
prompt_embeds=image_prompt_embeds
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
latents = self.vae.encode(pixel_values).latent_dist.sample()
|
| 174 |
+
latents = (latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 175 |
+
latents = latents.to(dtype=image_embeds.dtype)
|
| 176 |
+
|
| 177 |
+
latents_ori_img = torch.zeros_like(latents)
|
| 178 |
+
if data_type=="edit" and self.ori_inp_dit!="none":
|
| 179 |
+
latents_ori_img = self.vae.encode(pixel_values_ori_img).latent_dist.sample()
|
| 180 |
+
latents_ori_img = (latents_ori_img - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 181 |
+
latents_ori_img = latents_ori_img.to(dtype=image_embeds.dtype)
|
| 182 |
+
|
| 183 |
+
# Sample noise that we'll add to the latents
|
| 184 |
+
noise = torch.randn_like(latents)
|
| 185 |
+
bsz = latents.shape[0]
|
| 186 |
+
# Sample a random timestep for each image
|
| 187 |
+
# for weighting schemes where we sample timesteps non-uniformly
|
| 188 |
+
u = compute_density_for_timestep_sampling(
|
| 189 |
+
weighting_scheme=args.weighting_scheme,
|
| 190 |
+
batch_size=bsz,
|
| 191 |
+
logit_mean=args.logit_mean,
|
| 192 |
+
logit_std=args.logit_std,
|
| 193 |
+
mode_scale=args.mode_scale,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
|
| 197 |
+
timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
|
| 198 |
+
|
| 199 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
| 200 |
+
# (this is the forward diffusion process)
|
| 201 |
+
sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype).to(device=noise.device)
|
| 202 |
+
#noisy_model_input = (1.0 - sigmas) * noise + sigmas * latents
|
| 203 |
+
noisy_model_input = sigmas * noise + (1-sigmas) * latents
|
| 204 |
+
#noisy_model_input + (1-sigmas)*(latents - noise) = latents
|
| 205 |
+
# Get the additional image embedding for conditioning.
|
| 206 |
+
# Instead of getting a diagonal Gaussian here, we simply take the mode.
|
| 207 |
+
original_image_embeds = image_embeds_2d
|
| 208 |
+
|
| 209 |
+
if args.conditioning_dropout_prob is not None:
|
| 210 |
+
random_p = torch.rand(bsz, device=latents.device)
|
| 211 |
+
# Sample masks for the edit prompts.
|
| 212 |
+
prompt_mask = random_p < 2 * args.uncondition_prob
|
| 213 |
+
prompt_mask = prompt_mask.reshape(bsz, 1, 1)
|
| 214 |
+
# Final text conditioning.
|
| 215 |
+
#prompt_embeds = torch.where(prompt_mask, torch.zeros_like(prompt_embeds), prompt_embeds)
|
| 216 |
+
prompt_embeds = torch.where(prompt_mask, self.uncond_prompt_embeds.repeat(prompt_embeds.shape[0],1,1).to(prompt_embeds.device), prompt_embeds)
|
| 217 |
+
prompt_attention_mask = torch.where(prompt_mask[:,0], self.uncond_prompt_attention_mask.repeat(prompt_embeds.shape[0],1).to(prompt_embeds.device), prompt_attention_mask)
|
| 218 |
+
|
| 219 |
+
# Sample masks for the original images.
|
| 220 |
+
#random_p_vq = torch.rand(bsz, device=latents.device)
|
| 221 |
+
image_mask_dtype = original_image_embeds.dtype
|
| 222 |
+
image_mask = 1 - (
|
| 223 |
+
(random_p <= args.conditioning_dropout_prob).to(image_mask_dtype)
|
| 224 |
+
)
|
| 225 |
+
image_mask = image_mask.reshape(bsz, 1, 1, 1)
|
| 226 |
+
|
| 227 |
+
if data_type=="edit":
|
| 228 |
+
image_mask=0
|
| 229 |
+
# Final image conditioning.
|
| 230 |
+
original_image_embeds = image_mask * original_image_embeds
|
| 231 |
+
|
| 232 |
+
ori_latent_mask = 1 - (
|
| 233 |
+
(random_p >= args.uncondition_prob).to(image_mask_dtype)
|
| 234 |
+
* (random_p < 3 * args.uncondition_prob).to(image_mask_dtype)
|
| 235 |
+
)
|
| 236 |
+
ori_latent_mask = ori_latent_mask.reshape(bsz, 1, 1, 1)
|
| 237 |
+
latents_ori_img = ori_latent_mask * latents_ori_img
|
| 238 |
+
|
| 239 |
+
concatenated_noisy_latents = torch.cat([noisy_model_input, original_image_embeds], dim=1)
|
| 240 |
+
|
| 241 |
+
ref_image_hidden_states = None
|
| 242 |
+
if self.ori_inp_dit=="dim":
|
| 243 |
+
concatenated_noisy_latents = torch.cat([concatenated_noisy_latents, latents_ori_img], dim=1)
|
| 244 |
+
elif self.ori_inp_dit=="seq":
|
| 245 |
+
latents_ori_img = torch.cat([latents_ori_img, original_image_embeds], dim=1)
|
| 246 |
+
concatenated_noisy_latents = torch.cat([concatenated_noisy_latents, latents_ori_img], dim=2)
|
| 247 |
+
elif self.ori_inp_dit=="ref":
|
| 248 |
+
latents_ori_img = torch.cat([latents_ori_img, original_image_embeds], dim=1)
|
| 249 |
+
ref_image_hidden_states = latents_ori_img[:,None]
|
| 250 |
+
# Predict the noise residual
|
| 251 |
+
# scale the timesteps (reversal not needed as we used a reverse lerp above already)
|
| 252 |
+
timesteps = 1-timesteps / self.noise_scheduler.config.num_train_timesteps #timesteps / self.noise_scheduler.config.num_train_timesteps
|
| 253 |
+
model_pred = self.transformer(
|
| 254 |
+
hidden_states=concatenated_noisy_latents,
|
| 255 |
+
timestep=timesteps,
|
| 256 |
+
encoder_hidden_states=prompt_embeds,
|
| 257 |
+
encoder_attention_mask=prompt_attention_mask,
|
| 258 |
+
# ref_image_hidden_states = ref_image_hidden_states,
|
| 259 |
+
return_dict=False,
|
| 260 |
+
)[0]
|
| 261 |
+
if self.ori_inp_dit=="seq":
|
| 262 |
+
model_pred = model_pred[:, :, :args.diffusion_resolution//8, :]
|
| 263 |
+
|
| 264 |
+
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
| 265 |
+
target = latents - noise
|
| 266 |
+
# Conditioning dropout to support classifier-free guidance during inference. For more details
|
| 267 |
+
# check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
|
| 268 |
+
|
| 269 |
+
# Concatenate the `original_image_embeds` with the `noisy_latents`.
|
| 270 |
+
|
| 271 |
+
# Get the target for loss depending on the prediction type
|
| 272 |
+
loss = torch.mean(
|
| 273 |
+
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
|
| 274 |
+
1,
|
| 275 |
+
)
|
| 276 |
+
loss = loss.mean()
|
| 277 |
+
|
| 278 |
+
loss_value = loss.item()
|
| 279 |
+
|
| 280 |
+
return loss
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class Lumina2InstructPix2PixPipeline(Lumina2Pipeline):
|
| 284 |
+
|
| 285 |
+
@torch.no_grad()
|
| 286 |
+
def __call__(
|
| 287 |
+
self,
|
| 288 |
+
prompt: Union[str, List[str]] = None,
|
| 289 |
+
width: Optional[int] = None,
|
| 290 |
+
height: Optional[int] = None,
|
| 291 |
+
num_inference_steps: int = 30,
|
| 292 |
+
guidance_scale: float = 4.0,
|
| 293 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 294 |
+
sigmas: List[float] = None,
|
| 295 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 296 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 297 |
+
latents: Optional[torch.Tensor] = None,
|
| 298 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 299 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 300 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 301 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 302 |
+
output_type: Optional[str] = "pil",
|
| 303 |
+
return_dict: bool = True,
|
| 304 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 305 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 306 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 307 |
+
system_prompt: Optional[str] = None,
|
| 308 |
+
cfg_trunc_ratio=[0.0,1.0],
|
| 309 |
+
cfg_normalization: bool = False,
|
| 310 |
+
max_sequence_length: int = 256,
|
| 311 |
+
control_emd="text",
|
| 312 |
+
img_cfg_trunc_ratio =[0.0,1.0],
|
| 313 |
+
gen_image_embeds=None,only_t2i="vqconcat",image_prompt_embeds=None,ori_inp_img=None,img_guidance_scale=1.5,vq_guidance_scale=0,ori_inp_way="none",
|
| 314 |
+
) -> Union[ImagePipelineOutput, Tuple]:
|
| 315 |
+
|
| 316 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 317 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 318 |
+
self._guidance_scale = guidance_scale
|
| 319 |
+
self._attention_kwargs = attention_kwargs
|
| 320 |
+
|
| 321 |
+
num_images_per_prompt = gen_image_embeds.shape[0] if gen_image_embeds is not None else image_prompt_embeds.shape[0]
|
| 322 |
+
# 1. Check inputs. Raise error if not correct
|
| 323 |
+
self.check_inputs(
|
| 324 |
+
prompt,
|
| 325 |
+
height,
|
| 326 |
+
width,
|
| 327 |
+
negative_prompt,
|
| 328 |
+
prompt_embeds=prompt_embeds,
|
| 329 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 330 |
+
prompt_attention_mask=prompt_attention_mask,
|
| 331 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
| 332 |
+
max_sequence_length=max_sequence_length,
|
| 333 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
# 2. Define call parameters
|
| 337 |
+
if prompt is not None and isinstance(prompt, str):
|
| 338 |
+
batch_size = 1
|
| 339 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 340 |
+
batch_size = len(prompt)
|
| 341 |
+
else:
|
| 342 |
+
batch_size = prompt_embeds.shape[0]
|
| 343 |
+
|
| 344 |
+
device = self._execution_device
|
| 345 |
+
|
| 346 |
+
# 3. Encode input prompt
|
| 347 |
+
(
|
| 348 |
+
prompt_embeds,
|
| 349 |
+
prompt_attention_mask,
|
| 350 |
+
negative_prompt_embeds,
|
| 351 |
+
negative_prompt_attention_mask,
|
| 352 |
+
) = self.encode_prompt(
|
| 353 |
+
prompt,
|
| 354 |
+
self.do_classifier_free_guidance,
|
| 355 |
+
negative_prompt=negative_prompt,
|
| 356 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 357 |
+
device=device,
|
| 358 |
+
prompt_embeds=prompt_embeds,
|
| 359 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 360 |
+
prompt_attention_mask=prompt_attention_mask,
|
| 361 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
| 362 |
+
max_sequence_length=max_sequence_length,
|
| 363 |
+
system_prompt=system_prompt,
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
if gen_image_embeds is not None:
|
| 368 |
+
image_embeds_8=gen_image_embeds
|
| 369 |
+
|
| 370 |
+
if control_emd=="text":
|
| 371 |
+
pass
|
| 372 |
+
elif control_emd=="null":
|
| 373 |
+
prompt_embeds = torch.zeros_like(prompt_embeds)
|
| 374 |
+
prompt_attention_mask = torch.zeros_like(prompt_attention_mask)
|
| 375 |
+
negative_prompt_embeds = prompt_embeds
|
| 376 |
+
negative_prompt_attention_mask = prompt_attention_mask
|
| 377 |
+
|
| 378 |
+
if self.do_classifier_free_guidance:
|
| 379 |
+
prompt_embeds = torch.cat([negative_prompt_embeds,negative_prompt_embeds, prompt_embeds], dim=0)
|
| 380 |
+
prompt_attention_mask = torch.cat([negative_prompt_attention_mask,negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
| 381 |
+
# 4. Prepare latents.
|
| 382 |
+
latent_channels = self.vae.config.latent_channels #self.transformer.config.in_channels
|
| 383 |
+
latents = self.prepare_latents(
|
| 384 |
+
batch_size * num_images_per_prompt,
|
| 385 |
+
latent_channels,
|
| 386 |
+
height,
|
| 387 |
+
width,
|
| 388 |
+
prompt_embeds.dtype,
|
| 389 |
+
device,
|
| 390 |
+
generator,
|
| 391 |
+
latents,
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
latents_ori_img = torch.zeros_like(latents)
|
| 395 |
+
if ori_inp_img is not None and ori_inp_way !="none":
|
| 396 |
+
#fuck = torch.load(ori_inp_img).to(latents.device)
|
| 397 |
+
ori_inp_img = F.interpolate(ori_inp_img[None].to(latents.device,latents.dtype), size=(height,width), mode='bilinear',align_corners=False)
|
| 398 |
+
latents_ori_img = self.vae.encode(ori_inp_img).latent_dist.sample()
|
| 399 |
+
latents_ori_img = (latents_ori_img- self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 400 |
+
latents_ori_img = latents_ori_img.to(dtype=latents.dtype)
|
| 401 |
+
if ori_inp_way !="none":
|
| 402 |
+
negative_latents_ori_img = torch.zeros_like(latents_ori_img).to(prompt_embeds.dtype)
|
| 403 |
+
latents_ori_img = torch.cat([negative_latents_ori_img,latents_ori_img, latents_ori_img], dim=0) if self.do_classifier_free_guidance else latents_ori_img
|
| 404 |
+
|
| 405 |
+
vq_in_edit = False
|
| 406 |
+
if only_t2i==True:
|
| 407 |
+
image_latents = torch.zeros_like(latents)[:,:8]
|
| 408 |
+
elif only_t2i=="vqconcat":
|
| 409 |
+
image_embeds_2d = image_embeds_8.reshape(batch_size* num_images_per_prompt,24,24,image_embeds_8.shape[-1]).permute(0,3,1,2)
|
| 410 |
+
if ori_inp_img is not None and image_embeds_8.mean()!=0:
|
| 411 |
+
vq_in_edit = True
|
| 412 |
+
image_vq_latents = F.interpolate(image_embeds_2d, size=(height//8,width//8), mode='bilinear',align_corners=False).to(latents.device,latents.dtype)
|
| 413 |
+
image_latents = torch.zeros_like(image_vq_latents)
|
| 414 |
+
else:
|
| 415 |
+
image_latents = F.interpolate(image_embeds_2d, size=(height//8,width//8), mode='bilinear',align_corners=False).to(latents.device,latents.dtype)
|
| 416 |
+
|
| 417 |
+
negative_image_latents = torch.zeros_like(image_latents).to(prompt_embeds.dtype)
|
| 418 |
+
image_latents = torch.cat([negative_image_latents,image_latents, image_latents], dim=0) if self.do_classifier_free_guidance else image_latents
|
| 419 |
+
|
| 420 |
+
# 5. Prepare timesteps
|
| 421 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 422 |
+
image_seq_len = latents.shape[1]
|
| 423 |
+
mu = calculate_shift(
|
| 424 |
+
image_seq_len,
|
| 425 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 426 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 427 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 428 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 429 |
+
)
|
| 430 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 431 |
+
self.scheduler,
|
| 432 |
+
num_inference_steps,
|
| 433 |
+
device,
|
| 434 |
+
sigmas=sigmas,
|
| 435 |
+
mu=mu,
|
| 436 |
+
)
|
| 437 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 438 |
+
self._num_timesteps = len(timesteps)
|
| 439 |
+
|
| 440 |
+
self.scheduler.sigmas=self.scheduler.sigmas.to(latents.dtype) #hjc find bug
|
| 441 |
+
# 6. Denoising loop
|
| 442 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 443 |
+
for i, t in enumerate(timesteps):
|
| 444 |
+
# compute whether apply classifier-free truncation on this timestep
|
| 445 |
+
do_classifier_free_truncation = not ((i + 1) / num_inference_steps > cfg_trunc_ratio[0] and (i + 1) / num_inference_steps < cfg_trunc_ratio[1])
|
| 446 |
+
img_do_classifier_free_truncation = not ((i + 1) / num_inference_steps > img_cfg_trunc_ratio[0] and (i + 1) / num_inference_steps < img_cfg_trunc_ratio[1])
|
| 447 |
+
|
| 448 |
+
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
|
| 449 |
+
current_timestep = 1 - t / self.scheduler.config.num_train_timesteps
|
| 450 |
+
|
| 451 |
+
latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents
|
| 452 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 453 |
+
current_timestep = current_timestep.expand(latent_model_input.shape[0])
|
| 454 |
+
|
| 455 |
+
latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
|
| 456 |
+
|
| 457 |
+
ref_image_hidden_states = None
|
| 458 |
+
if ori_inp_way=="seq":
|
| 459 |
+
latents_ori_img_cat = torch.cat([latents_ori_img, image_latents], dim=1)
|
| 460 |
+
latent_model_input = torch.cat([latent_model_input, latents_ori_img_cat], dim=2)
|
| 461 |
+
elif ori_inp_way=="ref":
|
| 462 |
+
latents_ori_img_cat = torch.cat([latents_ori_img, image_latents], dim=1)
|
| 463 |
+
ref_image_hidden_states = latents_ori_img_cat[:,None]
|
| 464 |
+
|
| 465 |
+
if ori_inp_way=="ref":
|
| 466 |
+
noise_pred = self.transformer(
|
| 467 |
+
hidden_states=latent_model_input,
|
| 468 |
+
timestep=current_timestep,
|
| 469 |
+
encoder_hidden_states=prompt_embeds,
|
| 470 |
+
encoder_attention_mask=prompt_attention_mask,
|
| 471 |
+
return_dict=False,ref_image_hidden_states=ref_image_hidden_states,
|
| 472 |
+
attention_kwargs=self.attention_kwargs,
|
| 473 |
+
)[0]
|
| 474 |
+
else:
|
| 475 |
+
noise_pred = self.transformer(
|
| 476 |
+
hidden_states=latent_model_input,
|
| 477 |
+
timestep=current_timestep,
|
| 478 |
+
encoder_hidden_states=prompt_embeds,
|
| 479 |
+
encoder_attention_mask=prompt_attention_mask,
|
| 480 |
+
return_dict=False,
|
| 481 |
+
attention_kwargs=self.attention_kwargs,
|
| 482 |
+
)[0]
|
| 483 |
+
if ori_inp_way=="seq":
|
| 484 |
+
noise_pred = noise_pred[:,:,:height//8,:]
|
| 485 |
+
|
| 486 |
+
if vq_in_edit:
|
| 487 |
+
latent_model_vq_input = torch.cat([latents, image_vq_latents], dim=1)
|
| 488 |
+
if ori_inp_way=="seq":
|
| 489 |
+
latents_ori_img_cat_vq = torch.cat([torch.zeros_like(latents), image_vq_latents], dim=1)
|
| 490 |
+
latent_model_vq_input = torch.cat([latent_model_vq_input, latents_ori_img_cat_vq], dim=2)
|
| 491 |
+
|
| 492 |
+
noise_vq_pred = self.transformer(
|
| 493 |
+
hidden_states=latent_model_vq_input,
|
| 494 |
+
timestep=current_timestep[-1:],
|
| 495 |
+
encoder_hidden_states=prompt_embeds[-1:],
|
| 496 |
+
encoder_attention_mask=prompt_attention_mask[-1:],
|
| 497 |
+
return_dict=False,
|
| 498 |
+
attention_kwargs=self.attention_kwargs,
|
| 499 |
+
)[0]
|
| 500 |
+
if ori_inp_way=="seq":
|
| 501 |
+
noise_vq_pred = noise_vq_pred[:,:,:height//8,:]
|
| 502 |
+
# perform normalization-based guidance scale on a truncated timestep interval
|
| 503 |
+
if self.do_classifier_free_guidance:
|
| 504 |
+
noise_pred_uncond,noise_pred_img, noise_pred_text = noise_pred.chunk(3)
|
| 505 |
+
if not do_classifier_free_truncation and not img_do_classifier_free_truncation:
|
| 506 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_img)+ img_guidance_scale * (noise_pred_img - noise_pred_uncond)
|
| 507 |
+
elif not do_classifier_free_truncation and img_do_classifier_free_truncation:
|
| 508 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_img)+ 1 * (noise_pred_img - noise_pred_uncond)
|
| 509 |
+
elif do_classifier_free_truncation and not img_do_classifier_free_truncation:
|
| 510 |
+
noise_pred = noise_pred_uncond + 1 * (noise_pred_text - noise_pred_img)+ img_guidance_scale * (noise_pred_img - noise_pred_uncond)
|
| 511 |
+
else:
|
| 512 |
+
noise_pred = noise_pred_text
|
| 513 |
+
if vq_in_edit:
|
| 514 |
+
noise_pred = noise_pred +vq_guidance_scale*(noise_vq_pred-noise_pred_uncond)
|
| 515 |
+
# apply normalization after classifier-free guidance
|
| 516 |
+
if cfg_normalization:
|
| 517 |
+
cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True)
|
| 518 |
+
noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
| 519 |
+
noise_pred = noise_pred * (cond_norm / noise_norm)
|
| 520 |
+
else:
|
| 521 |
+
noise_pred = noise_pred
|
| 522 |
+
|
| 523 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 524 |
+
latents_dtype = latents.dtype
|
| 525 |
+
noise_pred = -noise_pred
|
| 526 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 527 |
+
|
| 528 |
+
if latents.dtype != latents_dtype:
|
| 529 |
+
if torch.backends.mps.is_available():
|
| 530 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 531 |
+
latents = latents.to(latents_dtype)
|
| 532 |
+
|
| 533 |
+
if callback_on_step_end is not None:
|
| 534 |
+
callback_kwargs = {}
|
| 535 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 536 |
+
callback_kwargs[k] = locals()[k]
|
| 537 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 538 |
+
|
| 539 |
+
latents = callback_outputs.pop("latents", latents)
|
| 540 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 541 |
+
|
| 542 |
+
# call the callback, if provided
|
| 543 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 544 |
+
progress_bar.update()
|
| 545 |
+
|
| 546 |
+
if XLA_AVAILABLE:
|
| 547 |
+
xm.mark_step()
|
| 548 |
+
|
| 549 |
+
if not output_type == "latent":
|
| 550 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 551 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 552 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 553 |
+
else:
|
| 554 |
+
image = latents
|
| 555 |
+
|
| 556 |
+
# Offload all models
|
| 557 |
+
self.maybe_free_model_hooks()
|
| 558 |
+
|
| 559 |
+
if not return_dict:
|
| 560 |
+
return (image,)
|
| 561 |
+
|
| 562 |
+
return ImagePipelineOutput(images=image)
|
| 563 |
+
|
star/models/pixel_decoder/transformer_lumina2.py
ADDED
|
@@ -0,0 +1,770 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from einops import rearrange
|
| 22 |
+
from diffusers.models.transformers.transformer_lumina2 import *
|
| 23 |
+
from einops import repeat
|
| 24 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
| 25 |
+
import itertools
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
hidden_size: int = 4096,
|
| 34 |
+
cap_feat_dim: int = 2048,
|
| 35 |
+
frequency_embedding_size: int = 256,
|
| 36 |
+
norm_eps: float = 1e-5,
|
| 37 |
+
) -> None:
|
| 38 |
+
super().__init__()
|
| 39 |
+
|
| 40 |
+
self.time_proj = Timesteps(
|
| 41 |
+
num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
self.timestep_embedder = TimestepEmbedding(
|
| 45 |
+
in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
self.caption_embedder = nn.Sequential(
|
| 49 |
+
RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True)
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def forward(
|
| 53 |
+
self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor
|
| 54 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 55 |
+
timestep_proj = self.time_proj(timestep).type_as(hidden_states[0])
|
| 56 |
+
time_embed = self.timestep_embedder(timestep_proj)
|
| 57 |
+
caption_embed = self.caption_embedder(encoder_hidden_states)
|
| 58 |
+
return time_embed, caption_embed
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class Lumina2AttnProcessor2_0:
|
| 62 |
+
r"""
|
| 63 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
| 64 |
+
used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
def __init__(self):
|
| 68 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 69 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
| 70 |
+
|
| 71 |
+
def __call__(
|
| 72 |
+
self,
|
| 73 |
+
attn: Attention,
|
| 74 |
+
hidden_states: torch.Tensor,
|
| 75 |
+
encoder_hidden_states: torch.Tensor,
|
| 76 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 77 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 78 |
+
base_sequence_length: Optional[int] = None,
|
| 79 |
+
) -> torch.Tensor:
|
| 80 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
| 81 |
+
|
| 82 |
+
# Get Query-Key-Value Pair
|
| 83 |
+
query = attn.to_q(hidden_states)
|
| 84 |
+
key = attn.to_k(encoder_hidden_states)
|
| 85 |
+
value = attn.to_v(encoder_hidden_states)
|
| 86 |
+
|
| 87 |
+
query_dim = query.shape[-1]
|
| 88 |
+
inner_dim = key.shape[-1]
|
| 89 |
+
head_dim = query_dim // attn.heads
|
| 90 |
+
dtype = query.dtype
|
| 91 |
+
|
| 92 |
+
# Get key-value heads
|
| 93 |
+
kv_heads = inner_dim // head_dim
|
| 94 |
+
|
| 95 |
+
query = query.view(batch_size, -1, attn.heads, head_dim)
|
| 96 |
+
key = key.view(batch_size, -1, kv_heads, head_dim)
|
| 97 |
+
value = value.view(batch_size, -1, kv_heads, head_dim)
|
| 98 |
+
|
| 99 |
+
# Apply Query-Key Norm if needed
|
| 100 |
+
if attn.norm_q is not None:
|
| 101 |
+
query = attn.norm_q(query)
|
| 102 |
+
if attn.norm_k is not None:
|
| 103 |
+
key = attn.norm_k(key)
|
| 104 |
+
|
| 105 |
+
# Apply RoPE if needed
|
| 106 |
+
if image_rotary_emb is not None:
|
| 107 |
+
query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
|
| 108 |
+
key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
|
| 109 |
+
|
| 110 |
+
query, key = query.to(dtype), key.to(dtype)
|
| 111 |
+
|
| 112 |
+
# Apply proportional attention if true
|
| 113 |
+
if base_sequence_length is not None:
|
| 114 |
+
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
|
| 115 |
+
else:
|
| 116 |
+
softmax_scale = attn.scale
|
| 117 |
+
|
| 118 |
+
# perform Grouped-qurey Attention (GQA)
|
| 119 |
+
n_rep = attn.heads // kv_heads
|
| 120 |
+
if n_rep >= 1:
|
| 121 |
+
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
| 122 |
+
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
| 123 |
+
|
| 124 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
| 125 |
+
# (batch, heads, source_length, target_length)
|
| 126 |
+
if attention_mask is not None:
|
| 127 |
+
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
|
| 128 |
+
|
| 129 |
+
query = query.transpose(1, 2)
|
| 130 |
+
key = key.transpose(1, 2)
|
| 131 |
+
value = value.transpose(1, 2)
|
| 132 |
+
|
| 133 |
+
hidden_states = F.scaled_dot_product_attention(
|
| 134 |
+
query, key, value, attn_mask=attention_mask, scale=softmax_scale
|
| 135 |
+
)
|
| 136 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 137 |
+
hidden_states = hidden_states.type_as(query)
|
| 138 |
+
|
| 139 |
+
# linear proj
|
| 140 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 141 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 142 |
+
return hidden_states
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class Lumina2TransformerBlock(nn.Module):
|
| 146 |
+
def __init__(
|
| 147 |
+
self,
|
| 148 |
+
dim: int,
|
| 149 |
+
num_attention_heads: int,
|
| 150 |
+
num_kv_heads: int,
|
| 151 |
+
multiple_of: int,
|
| 152 |
+
ffn_dim_multiplier: float,
|
| 153 |
+
norm_eps: float,
|
| 154 |
+
modulation: bool = True,
|
| 155 |
+
) -> None:
|
| 156 |
+
super().__init__()
|
| 157 |
+
self.head_dim = dim // num_attention_heads
|
| 158 |
+
self.dim = dim
|
| 159 |
+
self.modulation = modulation
|
| 160 |
+
|
| 161 |
+
self.attn = Attention(
|
| 162 |
+
query_dim=dim,
|
| 163 |
+
cross_attention_dim=None,
|
| 164 |
+
dim_head=dim // num_attention_heads,
|
| 165 |
+
qk_norm="rms_norm",
|
| 166 |
+
heads=num_attention_heads,
|
| 167 |
+
kv_heads=num_kv_heads,
|
| 168 |
+
eps=1e-5,
|
| 169 |
+
bias=False,
|
| 170 |
+
out_bias=False,
|
| 171 |
+
processor=Lumina2AttnProcessor2_0(),
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
self.feed_forward = LuminaFeedForward(
|
| 175 |
+
dim=dim,
|
| 176 |
+
inner_dim=4 * dim,
|
| 177 |
+
multiple_of=multiple_of,
|
| 178 |
+
ffn_dim_multiplier=ffn_dim_multiplier,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
if modulation:
|
| 182 |
+
self.norm1 = LuminaRMSNormZero(
|
| 183 |
+
embedding_dim=dim,
|
| 184 |
+
norm_eps=norm_eps,
|
| 185 |
+
norm_elementwise_affine=True,
|
| 186 |
+
)
|
| 187 |
+
else:
|
| 188 |
+
self.norm1 = RMSNorm(dim, eps=norm_eps)
|
| 189 |
+
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
| 190 |
+
|
| 191 |
+
self.norm2 = RMSNorm(dim, eps=norm_eps)
|
| 192 |
+
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
| 193 |
+
|
| 194 |
+
def forward(
|
| 195 |
+
self,
|
| 196 |
+
hidden_states: torch.Tensor,
|
| 197 |
+
attention_mask: torch.Tensor,
|
| 198 |
+
image_rotary_emb: torch.Tensor,
|
| 199 |
+
temb: Optional[torch.Tensor] = None,
|
| 200 |
+
) -> torch.Tensor:
|
| 201 |
+
if self.modulation:
|
| 202 |
+
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
| 203 |
+
attn_output = self.attn(
|
| 204 |
+
hidden_states=norm_hidden_states,
|
| 205 |
+
encoder_hidden_states=norm_hidden_states,
|
| 206 |
+
attention_mask=attention_mask,
|
| 207 |
+
image_rotary_emb=image_rotary_emb,
|
| 208 |
+
)
|
| 209 |
+
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
|
| 210 |
+
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
| 211 |
+
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
| 212 |
+
else:
|
| 213 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 214 |
+
attn_output = self.attn(
|
| 215 |
+
hidden_states=norm_hidden_states,
|
| 216 |
+
encoder_hidden_states=norm_hidden_states,
|
| 217 |
+
attention_mask=attention_mask,
|
| 218 |
+
image_rotary_emb=image_rotary_emb,
|
| 219 |
+
)
|
| 220 |
+
hidden_states = hidden_states + self.norm2(attn_output)
|
| 221 |
+
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
|
| 222 |
+
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
|
| 223 |
+
|
| 224 |
+
return hidden_states
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class Lumina2RotaryPosEmbed(nn.Module):
|
| 228 |
+
def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2):
|
| 229 |
+
super().__init__()
|
| 230 |
+
self.theta = theta
|
| 231 |
+
self.axes_dim = axes_dim
|
| 232 |
+
self.axes_lens = axes_lens
|
| 233 |
+
self.patch_size = patch_size
|
| 234 |
+
|
| 235 |
+
self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta)
|
| 236 |
+
|
| 237 |
+
def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
|
| 238 |
+
freqs_cis = []
|
| 239 |
+
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
| 240 |
+
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
|
| 241 |
+
emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
|
| 242 |
+
freqs_cis.append(emb)
|
| 243 |
+
return freqs_cis
|
| 244 |
+
|
| 245 |
+
def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
|
| 246 |
+
device = ids.device
|
| 247 |
+
if ids.device.type == "mps":
|
| 248 |
+
ids = ids.to("cpu")
|
| 249 |
+
|
| 250 |
+
result = []
|
| 251 |
+
for i in range(len(self.axes_dim)):
|
| 252 |
+
freqs = self.freqs_cis[i].to(ids.device)
|
| 253 |
+
index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
|
| 254 |
+
result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
|
| 255 |
+
return torch.cat(result, dim=-1).to(device)
|
| 256 |
+
|
| 257 |
+
def forward(
|
| 258 |
+
self,
|
| 259 |
+
attention_mask,
|
| 260 |
+
l_effective_ref_img_len,
|
| 261 |
+
l_effective_img_len,
|
| 262 |
+
ref_img_sizes,
|
| 263 |
+
img_sizes,
|
| 264 |
+
device
|
| 265 |
+
):
|
| 266 |
+
|
| 267 |
+
batch_size = len(attention_mask)
|
| 268 |
+
p = self.patch_size
|
| 269 |
+
|
| 270 |
+
encoder_seq_len = attention_mask.shape[1]
|
| 271 |
+
l_effective_cap_len = attention_mask.sum(dim=1).tolist()
|
| 272 |
+
|
| 273 |
+
seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
|
| 274 |
+
|
| 275 |
+
max_seq_len = max(seq_lengths)
|
| 276 |
+
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
|
| 277 |
+
max_img_len = max(l_effective_img_len)
|
| 278 |
+
|
| 279 |
+
# Create position IDs
|
| 280 |
+
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
|
| 281 |
+
|
| 282 |
+
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
|
| 283 |
+
# add text position ids
|
| 284 |
+
position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
|
| 285 |
+
|
| 286 |
+
pe_shift = cap_seq_len
|
| 287 |
+
pe_shift_len = cap_seq_len
|
| 288 |
+
|
| 289 |
+
if ref_img_sizes[i] is not None:
|
| 290 |
+
for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
|
| 291 |
+
H, W = ref_img_size
|
| 292 |
+
ref_H_tokens, ref_W_tokens = H // p, W // p
|
| 293 |
+
assert ref_H_tokens * ref_W_tokens == ref_img_len
|
| 294 |
+
# add image position ids
|
| 295 |
+
|
| 296 |
+
row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
|
| 297 |
+
col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
|
| 298 |
+
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
|
| 299 |
+
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
|
| 300 |
+
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
|
| 301 |
+
|
| 302 |
+
pe_shift += max(ref_H_tokens, ref_W_tokens)
|
| 303 |
+
pe_shift_len += ref_img_len
|
| 304 |
+
|
| 305 |
+
H, W = img_sizes[i]
|
| 306 |
+
H_tokens, W_tokens = H // p, W // p
|
| 307 |
+
assert H_tokens * W_tokens == l_effective_img_len[i]
|
| 308 |
+
|
| 309 |
+
row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
|
| 310 |
+
col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
|
| 311 |
+
|
| 312 |
+
assert pe_shift_len + l_effective_img_len[i] == seq_len
|
| 313 |
+
position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
|
| 314 |
+
position_ids[i, pe_shift_len: seq_len, 1] = row_ids
|
| 315 |
+
position_ids[i, pe_shift_len: seq_len, 2] = col_ids
|
| 316 |
+
|
| 317 |
+
# Get combined rotary embeddings
|
| 318 |
+
freqs_cis = self._get_freqs_cis(position_ids)
|
| 319 |
+
|
| 320 |
+
# create separate rotary embeddings for captions and images
|
| 321 |
+
cap_freqs_cis = torch.zeros(
|
| 322 |
+
batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
| 323 |
+
)
|
| 324 |
+
ref_img_freqs_cis = torch.zeros(
|
| 325 |
+
batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
| 326 |
+
)
|
| 327 |
+
img_freqs_cis = torch.zeros(
|
| 328 |
+
batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
|
| 332 |
+
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
|
| 333 |
+
ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
|
| 334 |
+
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
|
| 335 |
+
|
| 336 |
+
return (
|
| 337 |
+
cap_freqs_cis,
|
| 338 |
+
ref_img_freqs_cis,
|
| 339 |
+
img_freqs_cis,
|
| 340 |
+
freqs_cis,
|
| 341 |
+
l_effective_cap_len,
|
| 342 |
+
seq_lengths,
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
| 347 |
+
r"""
|
| 348 |
+
Lumina2NextDiT: Diffusion model with a Transformer backbone.
|
| 349 |
+
|
| 350 |
+
Parameters:
|
| 351 |
+
sample_size (`int`): The width of the latent images. This is fixed during training since
|
| 352 |
+
it is used to learn a number of position embeddings.
|
| 353 |
+
patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2):
|
| 354 |
+
The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
|
| 355 |
+
in_channels (`int`, *optional*, defaults to 4):
|
| 356 |
+
The number of input channels for the model. Typically, this matches the number of channels in the input
|
| 357 |
+
images.
|
| 358 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 359 |
+
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
|
| 360 |
+
hidden representations.
|
| 361 |
+
num_layers (`int`, *optional*, default to 32):
|
| 362 |
+
The number of layers in the model. This defines the depth of the neural network.
|
| 363 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 364 |
+
The number of attention heads in each attention layer. This parameter specifies how many separate attention
|
| 365 |
+
mechanisms are used.
|
| 366 |
+
num_kv_heads (`int`, *optional*, defaults to 8):
|
| 367 |
+
The number of key-value heads in the attention mechanism, if different from the number of attention heads.
|
| 368 |
+
If None, it defaults to num_attention_heads.
|
| 369 |
+
multiple_of (`int`, *optional*, defaults to 256):
|
| 370 |
+
A factor that the hidden size should be a multiple of. This can help optimize certain hardware
|
| 371 |
+
configurations.
|
| 372 |
+
ffn_dim_multiplier (`float`, *optional*):
|
| 373 |
+
A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
|
| 374 |
+
the model configuration.
|
| 375 |
+
norm_eps (`float`, *optional*, defaults to 1e-5):
|
| 376 |
+
A small value added to the denominator for numerical stability in normalization layers.
|
| 377 |
+
scaling_factor (`float`, *optional*, defaults to 1.0):
|
| 378 |
+
A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
|
| 379 |
+
overall scale of the model's operations.
|
| 380 |
+
"""
|
| 381 |
+
|
| 382 |
+
_supports_gradient_checkpointing = True
|
| 383 |
+
_no_split_modules = ["Lumina2TransformerBlock"]
|
| 384 |
+
_skip_layerwise_casting_patterns = ["x_embedder", "norm"]
|
| 385 |
+
|
| 386 |
+
@register_to_config
|
| 387 |
+
def __init__(
|
| 388 |
+
self,
|
| 389 |
+
sample_size: int = 128,
|
| 390 |
+
patch_size: int = 2,
|
| 391 |
+
in_channels: int = 16,
|
| 392 |
+
out_channels: Optional[int] = None,
|
| 393 |
+
hidden_size: int = 2304,
|
| 394 |
+
num_layers: int = 26,
|
| 395 |
+
num_refiner_layers: int = 2,
|
| 396 |
+
num_attention_heads: int = 24,
|
| 397 |
+
num_kv_heads: int = 8,
|
| 398 |
+
multiple_of: int = 256,
|
| 399 |
+
ffn_dim_multiplier: Optional[float] = None,
|
| 400 |
+
norm_eps: float = 1e-5,
|
| 401 |
+
scaling_factor: float = 1.0,
|
| 402 |
+
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
|
| 403 |
+
axes_lens: Tuple[int, int, int] = (300, 512, 512),
|
| 404 |
+
cap_feat_dim: int = 1024,
|
| 405 |
+
) -> None:
|
| 406 |
+
super().__init__()
|
| 407 |
+
self.out_channels = out_channels or in_channels
|
| 408 |
+
|
| 409 |
+
# 1. Positional, patch & conditional embeddings
|
| 410 |
+
self.rope_embedder = Lumina2RotaryPosEmbed(
|
| 411 |
+
theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size)
|
| 415 |
+
|
| 416 |
+
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
|
| 417 |
+
hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
# 2. Noise and context refinement blocks
|
| 421 |
+
self.noise_refiner = nn.ModuleList(
|
| 422 |
+
[
|
| 423 |
+
Lumina2TransformerBlock(
|
| 424 |
+
hidden_size,
|
| 425 |
+
num_attention_heads,
|
| 426 |
+
num_kv_heads,
|
| 427 |
+
multiple_of,
|
| 428 |
+
ffn_dim_multiplier,
|
| 429 |
+
norm_eps,
|
| 430 |
+
modulation=True,
|
| 431 |
+
)
|
| 432 |
+
for _ in range(num_refiner_layers)
|
| 433 |
+
]
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
self.context_refiner = nn.ModuleList(
|
| 437 |
+
[
|
| 438 |
+
Lumina2TransformerBlock(
|
| 439 |
+
hidden_size,
|
| 440 |
+
num_attention_heads,
|
| 441 |
+
num_kv_heads,
|
| 442 |
+
multiple_of,
|
| 443 |
+
ffn_dim_multiplier,
|
| 444 |
+
norm_eps,
|
| 445 |
+
modulation=False,
|
| 446 |
+
)
|
| 447 |
+
for _ in range(num_refiner_layers)
|
| 448 |
+
]
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
# 3. Transformer blocks
|
| 452 |
+
self.layers = nn.ModuleList(
|
| 453 |
+
[
|
| 454 |
+
Lumina2TransformerBlock(
|
| 455 |
+
hidden_size,
|
| 456 |
+
num_attention_heads,
|
| 457 |
+
num_kv_heads,
|
| 458 |
+
multiple_of,
|
| 459 |
+
ffn_dim_multiplier,
|
| 460 |
+
norm_eps,
|
| 461 |
+
modulation=True,
|
| 462 |
+
)
|
| 463 |
+
for _ in range(num_layers)
|
| 464 |
+
]
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
# 4. Output norm & projection
|
| 468 |
+
self.norm_out = LuminaLayerNormContinuous(
|
| 469 |
+
embedding_dim=hidden_size,
|
| 470 |
+
conditioning_embedding_dim=min(hidden_size, 1024),
|
| 471 |
+
elementwise_affine=False,
|
| 472 |
+
eps=1e-6,
|
| 473 |
+
bias=True,
|
| 474 |
+
out_dim=patch_size * patch_size * self.out_channels,
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
self.gradient_checkpointing = False
|
| 478 |
+
|
| 479 |
+
self.args_dict = {"patch_size":patch_size,"in_channels":in_channels,"hidden_size":hidden_size,
|
| 480 |
+
"num_attention_heads":num_attention_heads,"num_kv_heads":num_kv_heads,
|
| 481 |
+
"multiple_of":multiple_of,"ffn_dim_multiplier":ffn_dim_multiplier,
|
| 482 |
+
"norm_eps":norm_eps,"num_refiner_layers":num_refiner_layers}
|
| 483 |
+
|
| 484 |
+
def initialize_ref_weights(self) -> None:
|
| 485 |
+
"""
|
| 486 |
+
Initialize the weights of the model.
|
| 487 |
+
|
| 488 |
+
Uses Xavier uniform initialization for linear layers.
|
| 489 |
+
"""
|
| 490 |
+
patch_size, in_channels, hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, num_refiner_layers = \
|
| 491 |
+
(self.args_dict[k] for k in ["patch_size","in_channels","hidden_size","num_attention_heads","num_kv_heads",
|
| 492 |
+
"multiple_of","ffn_dim_multiplier","norm_eps","num_refiner_layers"])
|
| 493 |
+
with torch.no_grad():
|
| 494 |
+
self.ref_image_patch_embedder = nn.Linear(
|
| 495 |
+
in_features=self.x_embedder.in_features,
|
| 496 |
+
out_features=hidden_size,
|
| 497 |
+
)
|
| 498 |
+
self.ref_image_refiner = nn.ModuleList([
|
| 499 |
+
Lumina2TransformerBlock(
|
| 500 |
+
hidden_size,
|
| 501 |
+
num_attention_heads,
|
| 502 |
+
num_kv_heads,
|
| 503 |
+
multiple_of,
|
| 504 |
+
ffn_dim_multiplier,
|
| 505 |
+
norm_eps,
|
| 506 |
+
modulation=True
|
| 507 |
+
)
|
| 508 |
+
for _ in range(num_refiner_layers)
|
| 509 |
+
])
|
| 510 |
+
nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
|
| 511 |
+
nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
|
| 512 |
+
|
| 513 |
+
# Add learnable embeddings to distinguish different images
|
| 514 |
+
self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images
|
| 515 |
+
nn.init.normal_(self.image_index_embedding, std=0.02)
|
| 516 |
+
|
| 517 |
+
def img_patch_embed_and_refine(
|
| 518 |
+
self,
|
| 519 |
+
hidden_states,
|
| 520 |
+
ref_image_hidden_states,
|
| 521 |
+
padded_img_mask,
|
| 522 |
+
padded_ref_img_mask,
|
| 523 |
+
noise_rotary_emb,
|
| 524 |
+
ref_img_rotary_emb,
|
| 525 |
+
l_effective_ref_img_len,
|
| 526 |
+
l_effective_img_len,
|
| 527 |
+
temb
|
| 528 |
+
):
|
| 529 |
+
batch_size = len(hidden_states)
|
| 530 |
+
max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])
|
| 531 |
+
|
| 532 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 533 |
+
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
|
| 534 |
+
|
| 535 |
+
for i in range(batch_size):
|
| 536 |
+
shift = 0
|
| 537 |
+
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
|
| 538 |
+
ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
|
| 539 |
+
shift += ref_img_len
|
| 540 |
+
|
| 541 |
+
for layer in self.noise_refiner:
|
| 542 |
+
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
|
| 543 |
+
|
| 544 |
+
flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
|
| 545 |
+
num_ref_images = len(flat_l_effective_ref_img_len)
|
| 546 |
+
max_ref_img_len = max(flat_l_effective_ref_img_len)
|
| 547 |
+
|
| 548 |
+
batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
|
| 549 |
+
batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
|
| 550 |
+
batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
|
| 551 |
+
batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
|
| 552 |
+
|
| 553 |
+
# sequence of ref imgs to batch
|
| 554 |
+
idx = 0
|
| 555 |
+
for i in range(batch_size):
|
| 556 |
+
shift = 0
|
| 557 |
+
for ref_img_len in l_effective_ref_img_len[i]:
|
| 558 |
+
batch_ref_img_mask[idx, :ref_img_len] = True
|
| 559 |
+
batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
|
| 560 |
+
batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
|
| 561 |
+
batch_temb[idx] = temb[i]
|
| 562 |
+
shift += ref_img_len
|
| 563 |
+
idx += 1
|
| 564 |
+
|
| 565 |
+
# refine ref imgs separately
|
| 566 |
+
for layer in self.ref_image_refiner:
|
| 567 |
+
batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)
|
| 568 |
+
|
| 569 |
+
# batch of ref imgs to sequence
|
| 570 |
+
idx = 0
|
| 571 |
+
for i in range(batch_size):
|
| 572 |
+
shift = 0
|
| 573 |
+
for ref_img_len in l_effective_ref_img_len[i]:
|
| 574 |
+
ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
|
| 575 |
+
shift += ref_img_len
|
| 576 |
+
idx += 1
|
| 577 |
+
|
| 578 |
+
combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
|
| 579 |
+
for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
|
| 580 |
+
combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
|
| 581 |
+
combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]
|
| 582 |
+
|
| 583 |
+
return combined_img_hidden_states
|
| 584 |
+
|
| 585 |
+
def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
|
| 586 |
+
batch_size = len(hidden_states)
|
| 587 |
+
p = self.config.patch_size
|
| 588 |
+
device = hidden_states[0].device
|
| 589 |
+
|
| 590 |
+
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
|
| 591 |
+
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
|
| 592 |
+
|
| 593 |
+
if ref_image_hidden_states is not None:
|
| 594 |
+
ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states]
|
| 595 |
+
l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
|
| 596 |
+
else:
|
| 597 |
+
ref_img_sizes = [None for _ in range(batch_size)]
|
| 598 |
+
l_effective_ref_img_len = [[0] for _ in range(batch_size)]
|
| 599 |
+
|
| 600 |
+
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
|
| 601 |
+
max_img_len = max(l_effective_img_len)
|
| 602 |
+
|
| 603 |
+
# ref image patch embeddings
|
| 604 |
+
flat_ref_img_hidden_states = []
|
| 605 |
+
for i in range(batch_size):
|
| 606 |
+
if ref_img_sizes[i] is not None:
|
| 607 |
+
imgs = []
|
| 608 |
+
for ref_img in ref_image_hidden_states[i]:
|
| 609 |
+
C, H, W = ref_img.size()
|
| 610 |
+
ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
|
| 611 |
+
imgs.append(ref_img)
|
| 612 |
+
|
| 613 |
+
img = torch.cat(imgs, dim=0)
|
| 614 |
+
flat_ref_img_hidden_states.append(img)
|
| 615 |
+
else:
|
| 616 |
+
flat_ref_img_hidden_states.append(None)
|
| 617 |
+
|
| 618 |
+
# image patch embeddings
|
| 619 |
+
flat_hidden_states = []
|
| 620 |
+
for i in range(batch_size):
|
| 621 |
+
img = hidden_states[i]
|
| 622 |
+
C, H, W = img.size()
|
| 623 |
+
|
| 624 |
+
img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
|
| 625 |
+
flat_hidden_states.append(img)
|
| 626 |
+
|
| 627 |
+
padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
|
| 628 |
+
padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
|
| 629 |
+
for i in range(batch_size):
|
| 630 |
+
if ref_img_sizes[i] is not None:
|
| 631 |
+
padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
|
| 632 |
+
padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True
|
| 633 |
+
|
| 634 |
+
padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
|
| 635 |
+
padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
|
| 636 |
+
for i in range(batch_size):
|
| 637 |
+
padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
|
| 638 |
+
padded_img_mask[i, :l_effective_img_len[i]] = True
|
| 639 |
+
|
| 640 |
+
return (
|
| 641 |
+
padded_hidden_states,
|
| 642 |
+
padded_ref_img_hidden_states,
|
| 643 |
+
padded_img_mask,
|
| 644 |
+
padded_ref_img_mask,
|
| 645 |
+
l_effective_ref_img_len,
|
| 646 |
+
l_effective_img_len,
|
| 647 |
+
ref_img_sizes,
|
| 648 |
+
img_sizes,
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
def forward(
|
| 652 |
+
self,
|
| 653 |
+
hidden_states: torch.Tensor,
|
| 654 |
+
timestep: torch.Tensor,
|
| 655 |
+
encoder_hidden_states: torch.Tensor,
|
| 656 |
+
encoder_attention_mask: torch.Tensor,
|
| 657 |
+
ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
|
| 658 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 659 |
+
return_dict: bool = True,
|
| 660 |
+
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
| 661 |
+
if attention_kwargs is not None:
|
| 662 |
+
attention_kwargs = attention_kwargs.copy()
|
| 663 |
+
lora_scale = attention_kwargs.pop("scale", 1.0)
|
| 664 |
+
else:
|
| 665 |
+
lora_scale = 1.0
|
| 666 |
+
|
| 667 |
+
if USE_PEFT_BACKEND:
|
| 668 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 669 |
+
scale_lora_layers(self, lora_scale)
|
| 670 |
+
else:
|
| 671 |
+
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
| 672 |
+
logger.warning(
|
| 673 |
+
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
# 1. Condition, positional & patch embedding
|
| 677 |
+
batch_size = len(hidden_states)
|
| 678 |
+
is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)
|
| 679 |
+
|
| 680 |
+
if is_hidden_states_tensor:
|
| 681 |
+
assert hidden_states.ndim == 4
|
| 682 |
+
hidden_states = [_hidden_states for _hidden_states in hidden_states]
|
| 683 |
+
|
| 684 |
+
device = hidden_states[0].device
|
| 685 |
+
|
| 686 |
+
temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
|
| 687 |
+
|
| 688 |
+
(
|
| 689 |
+
hidden_states,
|
| 690 |
+
ref_image_hidden_states,
|
| 691 |
+
img_mask,
|
| 692 |
+
ref_img_mask,
|
| 693 |
+
l_effective_ref_img_len,
|
| 694 |
+
l_effective_img_len,
|
| 695 |
+
ref_img_sizes,
|
| 696 |
+
img_sizes,
|
| 697 |
+
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
|
| 698 |
+
|
| 699 |
+
(
|
| 700 |
+
context_rotary_emb,
|
| 701 |
+
ref_img_rotary_emb,
|
| 702 |
+
noise_rotary_emb,
|
| 703 |
+
rotary_emb,
|
| 704 |
+
encoder_seq_lengths,
|
| 705 |
+
seq_lengths,
|
| 706 |
+
) = self.rope_embedder(
|
| 707 |
+
encoder_attention_mask,
|
| 708 |
+
l_effective_ref_img_len,
|
| 709 |
+
l_effective_img_len,
|
| 710 |
+
ref_img_sizes,
|
| 711 |
+
img_sizes,
|
| 712 |
+
device,
|
| 713 |
+
)
|
| 714 |
+
|
| 715 |
+
# 2. Context & noise refinement
|
| 716 |
+
for layer in self.context_refiner:
|
| 717 |
+
encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)
|
| 718 |
+
|
| 719 |
+
combined_img_hidden_states = self.img_patch_embed_and_refine(
|
| 720 |
+
hidden_states,
|
| 721 |
+
ref_image_hidden_states,
|
| 722 |
+
img_mask,
|
| 723 |
+
ref_img_mask,
|
| 724 |
+
noise_rotary_emb,
|
| 725 |
+
ref_img_rotary_emb,
|
| 726 |
+
l_effective_ref_img_len,
|
| 727 |
+
l_effective_img_len,
|
| 728 |
+
temb,
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
# 3. Joint Transformer blocks
|
| 732 |
+
max_seq_len = max(seq_lengths)
|
| 733 |
+
use_mask = len(set(seq_lengths)) > 1
|
| 734 |
+
|
| 735 |
+
attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
|
| 736 |
+
joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
|
| 737 |
+
for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
|
| 738 |
+
attention_mask[i, :seq_len] = True
|
| 739 |
+
joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
|
| 740 |
+
joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]
|
| 741 |
+
|
| 742 |
+
hidden_states = joint_hidden_states
|
| 743 |
+
|
| 744 |
+
for layer in self.layers:
|
| 745 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 746 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 747 |
+
layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb
|
| 748 |
+
)
|
| 749 |
+
else:
|
| 750 |
+
hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb)
|
| 751 |
+
|
| 752 |
+
# 4. Output norm & projection
|
| 753 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 754 |
+
|
| 755 |
+
# 5. Unpatchify
|
| 756 |
+
p = self.config.patch_size
|
| 757 |
+
output = []
|
| 758 |
+
for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
|
| 759 |
+
height, width = img_size
|
| 760 |
+
output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p))
|
| 761 |
+
if is_hidden_states_tensor:
|
| 762 |
+
output = torch.stack(output, dim=0)
|
| 763 |
+
|
| 764 |
+
if USE_PEFT_BACKEND:
|
| 765 |
+
# remove `lora_scale` from each PEFT layer
|
| 766 |
+
unscale_lora_layers(self, lora_scale)
|
| 767 |
+
|
| 768 |
+
if not return_dict:
|
| 769 |
+
return (output,)
|
| 770 |
+
return Transformer2DModelOutput(sample=output)
|
star/models/pixel_decoder/transformer_lumina2_seq.py
ADDED
|
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from einops import rearrange
|
| 22 |
+
from diffusers.models.transformers.transformer_lumina2 import *
|
| 23 |
+
from einops import repeat
|
| 24 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
| 25 |
+
import itertools
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
hidden_size: int = 4096,
|
| 34 |
+
cap_feat_dim: int = 2048,
|
| 35 |
+
frequency_embedding_size: int = 256,
|
| 36 |
+
norm_eps: float = 1e-5,
|
| 37 |
+
) -> None:
|
| 38 |
+
super().__init__()
|
| 39 |
+
|
| 40 |
+
self.time_proj = Timesteps(
|
| 41 |
+
num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
self.timestep_embedder = TimestepEmbedding(
|
| 45 |
+
in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
self.caption_embedder = nn.Sequential(
|
| 49 |
+
RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True)
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def forward(
|
| 53 |
+
self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor
|
| 54 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 55 |
+
timestep_proj = self.time_proj(timestep).type_as(hidden_states)
|
| 56 |
+
time_embed = self.timestep_embedder(timestep_proj)
|
| 57 |
+
caption_embed = self.caption_embedder(encoder_hidden_states)
|
| 58 |
+
return time_embed, caption_embed
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class Lumina2AttnProcessor2_0:
|
| 62 |
+
r"""
|
| 63 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
| 64 |
+
used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
def __init__(self):
|
| 68 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 69 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
| 70 |
+
|
| 71 |
+
def __call__(
|
| 72 |
+
self,
|
| 73 |
+
attn: Attention,
|
| 74 |
+
hidden_states: torch.Tensor,
|
| 75 |
+
encoder_hidden_states: torch.Tensor,
|
| 76 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 77 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 78 |
+
base_sequence_length: Optional[int] = None,
|
| 79 |
+
) -> torch.Tensor:
|
| 80 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
| 81 |
+
|
| 82 |
+
# Get Query-Key-Value Pair
|
| 83 |
+
query = attn.to_q(hidden_states)
|
| 84 |
+
key = attn.to_k(encoder_hidden_states)
|
| 85 |
+
value = attn.to_v(encoder_hidden_states)
|
| 86 |
+
|
| 87 |
+
query_dim = query.shape[-1]
|
| 88 |
+
inner_dim = key.shape[-1]
|
| 89 |
+
head_dim = query_dim // attn.heads
|
| 90 |
+
dtype = query.dtype
|
| 91 |
+
|
| 92 |
+
# Get key-value heads
|
| 93 |
+
kv_heads = inner_dim // head_dim
|
| 94 |
+
|
| 95 |
+
query = query.view(batch_size, -1, attn.heads, head_dim)
|
| 96 |
+
key = key.view(batch_size, -1, kv_heads, head_dim)
|
| 97 |
+
value = value.view(batch_size, -1, kv_heads, head_dim)
|
| 98 |
+
|
| 99 |
+
# Apply Query-Key Norm if needed
|
| 100 |
+
if attn.norm_q is not None:
|
| 101 |
+
query = attn.norm_q(query)
|
| 102 |
+
if attn.norm_k is not None:
|
| 103 |
+
key = attn.norm_k(key)
|
| 104 |
+
|
| 105 |
+
# Apply RoPE if needed
|
| 106 |
+
if image_rotary_emb is not None:
|
| 107 |
+
query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
|
| 108 |
+
key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
|
| 109 |
+
|
| 110 |
+
query, key = query.to(dtype), key.to(dtype)
|
| 111 |
+
|
| 112 |
+
# Apply proportional attention if true
|
| 113 |
+
if base_sequence_length is not None:
|
| 114 |
+
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
|
| 115 |
+
else:
|
| 116 |
+
softmax_scale = attn.scale
|
| 117 |
+
|
| 118 |
+
# perform Grouped-qurey Attention (GQA)
|
| 119 |
+
n_rep = attn.heads // kv_heads
|
| 120 |
+
if n_rep >= 1:
|
| 121 |
+
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
| 122 |
+
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
| 123 |
+
|
| 124 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
| 125 |
+
# (batch, heads, source_length, target_length)
|
| 126 |
+
if attention_mask is not None:
|
| 127 |
+
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
|
| 128 |
+
|
| 129 |
+
query = query.transpose(1, 2)
|
| 130 |
+
key = key.transpose(1, 2)
|
| 131 |
+
value = value.transpose(1, 2)
|
| 132 |
+
|
| 133 |
+
hidden_states = F.scaled_dot_product_attention(
|
| 134 |
+
query, key, value, attn_mask=attention_mask, scale=softmax_scale
|
| 135 |
+
)
|
| 136 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 137 |
+
hidden_states = hidden_states.type_as(query)
|
| 138 |
+
|
| 139 |
+
# linear proj
|
| 140 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 141 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 142 |
+
return hidden_states
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class Lumina2TransformerBlock(nn.Module):
|
| 146 |
+
def __init__(
|
| 147 |
+
self,
|
| 148 |
+
dim: int,
|
| 149 |
+
num_attention_heads: int,
|
| 150 |
+
num_kv_heads: int,
|
| 151 |
+
multiple_of: int,
|
| 152 |
+
ffn_dim_multiplier: float,
|
| 153 |
+
norm_eps: float,
|
| 154 |
+
modulation: bool = True,
|
| 155 |
+
) -> None:
|
| 156 |
+
super().__init__()
|
| 157 |
+
self.head_dim = dim // num_attention_heads
|
| 158 |
+
self.dim = dim
|
| 159 |
+
self.modulation = modulation
|
| 160 |
+
|
| 161 |
+
self.attn = Attention(
|
| 162 |
+
query_dim=dim,
|
| 163 |
+
cross_attention_dim=None,
|
| 164 |
+
dim_head=dim // num_attention_heads,
|
| 165 |
+
qk_norm="rms_norm",
|
| 166 |
+
heads=num_attention_heads,
|
| 167 |
+
kv_heads=num_kv_heads,
|
| 168 |
+
eps=1e-5,
|
| 169 |
+
bias=False,
|
| 170 |
+
out_bias=False,
|
| 171 |
+
processor=Lumina2AttnProcessor2_0(),
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
self.feed_forward = LuminaFeedForward(
|
| 175 |
+
dim=dim,
|
| 176 |
+
inner_dim=4 * dim,
|
| 177 |
+
multiple_of=multiple_of,
|
| 178 |
+
ffn_dim_multiplier=ffn_dim_multiplier,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
if modulation:
|
| 182 |
+
self.norm1 = LuminaRMSNormZero(
|
| 183 |
+
embedding_dim=dim,
|
| 184 |
+
norm_eps=norm_eps,
|
| 185 |
+
norm_elementwise_affine=True,
|
| 186 |
+
)
|
| 187 |
+
else:
|
| 188 |
+
self.norm1 = RMSNorm(dim, eps=norm_eps)
|
| 189 |
+
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
| 190 |
+
|
| 191 |
+
self.norm2 = RMSNorm(dim, eps=norm_eps)
|
| 192 |
+
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
| 193 |
+
|
| 194 |
+
def forward(
|
| 195 |
+
self,
|
| 196 |
+
hidden_states: torch.Tensor,
|
| 197 |
+
attention_mask: torch.Tensor,
|
| 198 |
+
image_rotary_emb: torch.Tensor,
|
| 199 |
+
temb: Optional[torch.Tensor] = None,
|
| 200 |
+
) -> torch.Tensor:
|
| 201 |
+
if self.modulation:
|
| 202 |
+
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
| 203 |
+
attn_output = self.attn(
|
| 204 |
+
hidden_states=norm_hidden_states,
|
| 205 |
+
encoder_hidden_states=norm_hidden_states,
|
| 206 |
+
attention_mask=attention_mask,
|
| 207 |
+
image_rotary_emb=image_rotary_emb,
|
| 208 |
+
)
|
| 209 |
+
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
|
| 210 |
+
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
| 211 |
+
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
| 212 |
+
else:
|
| 213 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 214 |
+
attn_output = self.attn(
|
| 215 |
+
hidden_states=norm_hidden_states,
|
| 216 |
+
encoder_hidden_states=norm_hidden_states,
|
| 217 |
+
attention_mask=attention_mask,
|
| 218 |
+
image_rotary_emb=image_rotary_emb,
|
| 219 |
+
)
|
| 220 |
+
hidden_states = hidden_states + self.norm2(attn_output)
|
| 221 |
+
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
|
| 222 |
+
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
|
| 223 |
+
|
| 224 |
+
return hidden_states
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class Lumina2RotaryPosEmbed(nn.Module):
|
| 228 |
+
def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2):
|
| 229 |
+
super().__init__()
|
| 230 |
+
self.theta = theta
|
| 231 |
+
self.axes_dim = axes_dim
|
| 232 |
+
self.axes_lens = axes_lens
|
| 233 |
+
self.patch_size = patch_size
|
| 234 |
+
|
| 235 |
+
self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta)
|
| 236 |
+
|
| 237 |
+
def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
|
| 238 |
+
freqs_cis = []
|
| 239 |
+
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
| 240 |
+
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
|
| 241 |
+
emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
|
| 242 |
+
freqs_cis.append(emb)
|
| 243 |
+
return freqs_cis
|
| 244 |
+
|
| 245 |
+
def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
|
| 246 |
+
device = ids.device
|
| 247 |
+
if ids.device.type == "mps":
|
| 248 |
+
ids = ids.to("cpu")
|
| 249 |
+
|
| 250 |
+
result = []
|
| 251 |
+
for i in range(len(self.axes_dim)):
|
| 252 |
+
freqs = self.freqs_cis[i].to(ids.device)
|
| 253 |
+
index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
|
| 254 |
+
result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
|
| 255 |
+
return torch.cat(result, dim=-1).to(device)
|
| 256 |
+
|
| 257 |
+
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
|
| 258 |
+
batch_size, channels, height, width = hidden_states.shape
|
| 259 |
+
p = self.patch_size
|
| 260 |
+
post_patch_height, post_patch_width = height // p, width // p
|
| 261 |
+
image_seq_len = post_patch_height * post_patch_width
|
| 262 |
+
device = hidden_states.device
|
| 263 |
+
|
| 264 |
+
encoder_seq_len = attention_mask.shape[1]
|
| 265 |
+
l_effective_cap_len = attention_mask.sum(dim=1).tolist()
|
| 266 |
+
seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
|
| 267 |
+
max_seq_len = max(seq_lengths)
|
| 268 |
+
|
| 269 |
+
# Create position IDs
|
| 270 |
+
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
|
| 271 |
+
|
| 272 |
+
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
|
| 273 |
+
# add caption position ids
|
| 274 |
+
position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
|
| 275 |
+
position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len
|
| 276 |
+
|
| 277 |
+
# add image position ids
|
| 278 |
+
row_ids = (
|
| 279 |
+
torch.arange(post_patch_height, dtype=torch.int32, device=device)
|
| 280 |
+
.view(-1, 1)
|
| 281 |
+
.repeat(1, post_patch_width)
|
| 282 |
+
.flatten()
|
| 283 |
+
)
|
| 284 |
+
col_ids = (
|
| 285 |
+
torch.arange(post_patch_width, dtype=torch.int32, device=device)
|
| 286 |
+
.view(1, -1)
|
| 287 |
+
.repeat(post_patch_height, 1)
|
| 288 |
+
.flatten()
|
| 289 |
+
)
|
| 290 |
+
position_ids[i, cap_seq_len:seq_len, 1] = row_ids
|
| 291 |
+
position_ids[i, cap_seq_len:seq_len, 2] = col_ids
|
| 292 |
+
|
| 293 |
+
# Get combined rotary embeddings
|
| 294 |
+
freqs_cis = self._get_freqs_cis(position_ids)
|
| 295 |
+
|
| 296 |
+
# create separate rotary embeddings for captions and images
|
| 297 |
+
cap_freqs_cis = torch.zeros(
|
| 298 |
+
batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
| 299 |
+
)
|
| 300 |
+
img_freqs_cis = torch.zeros(
|
| 301 |
+
batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
|
| 305 |
+
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
|
| 306 |
+
img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]
|
| 307 |
+
|
| 308 |
+
# image patch embeddings
|
| 309 |
+
hidden_states = (
|
| 310 |
+
hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
|
| 311 |
+
.permute(0, 2, 4, 3, 5, 1)
|
| 312 |
+
.flatten(3)
|
| 313 |
+
.flatten(1, 2)
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
| 320 |
+
r"""
|
| 321 |
+
Lumina2NextDiT: Diffusion model with a Transformer backbone.
|
| 322 |
+
|
| 323 |
+
Parameters:
|
| 324 |
+
sample_size (`int`): The width of the latent images. This is fixed during training since
|
| 325 |
+
it is used to learn a number of position embeddings.
|
| 326 |
+
patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2):
|
| 327 |
+
The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
|
| 328 |
+
in_channels (`int`, *optional*, defaults to 4):
|
| 329 |
+
The number of input channels for the model. Typically, this matches the number of channels in the input
|
| 330 |
+
images.
|
| 331 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 332 |
+
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
|
| 333 |
+
hidden representations.
|
| 334 |
+
num_layers (`int`, *optional*, default to 32):
|
| 335 |
+
The number of layers in the model. This defines the depth of the neural network.
|
| 336 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 337 |
+
The number of attention heads in each attention layer. This parameter specifies how many separate attention
|
| 338 |
+
mechanisms are used.
|
| 339 |
+
num_kv_heads (`int`, *optional*, defaults to 8):
|
| 340 |
+
The number of key-value heads in the attention mechanism, if different from the number of attention heads.
|
| 341 |
+
If None, it defaults to num_attention_heads.
|
| 342 |
+
multiple_of (`int`, *optional*, defaults to 256):
|
| 343 |
+
A factor that the hidden size should be a multiple of. This can help optimize certain hardware
|
| 344 |
+
configurations.
|
| 345 |
+
ffn_dim_multiplier (`float`, *optional*):
|
| 346 |
+
A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
|
| 347 |
+
the model configuration.
|
| 348 |
+
norm_eps (`float`, *optional*, defaults to 1e-5):
|
| 349 |
+
A small value added to the denominator for numerical stability in normalization layers.
|
| 350 |
+
scaling_factor (`float`, *optional*, defaults to 1.0):
|
| 351 |
+
A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
|
| 352 |
+
overall scale of the model's operations.
|
| 353 |
+
"""
|
| 354 |
+
|
| 355 |
+
_supports_gradient_checkpointing = True
|
| 356 |
+
_no_split_modules = ["Lumina2TransformerBlock"]
|
| 357 |
+
_skip_layerwise_casting_patterns = ["x_embedder", "norm"]
|
| 358 |
+
|
| 359 |
+
@register_to_config
|
| 360 |
+
def __init__(
|
| 361 |
+
self,
|
| 362 |
+
sample_size: int = 128,
|
| 363 |
+
patch_size: int = 2,
|
| 364 |
+
in_channels: int = 16,
|
| 365 |
+
out_channels: Optional[int] = None,
|
| 366 |
+
hidden_size: int = 2304,
|
| 367 |
+
num_layers: int = 26,
|
| 368 |
+
num_refiner_layers: int = 2,
|
| 369 |
+
num_attention_heads: int = 24,
|
| 370 |
+
num_kv_heads: int = 8,
|
| 371 |
+
multiple_of: int = 256,
|
| 372 |
+
ffn_dim_multiplier: Optional[float] = None,
|
| 373 |
+
norm_eps: float = 1e-5,
|
| 374 |
+
scaling_factor: float = 1.0,
|
| 375 |
+
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
|
| 376 |
+
axes_lens: Tuple[int, int, int] = (300, 512, 512),
|
| 377 |
+
cap_feat_dim: int = 1024,
|
| 378 |
+
) -> None:
|
| 379 |
+
super().__init__()
|
| 380 |
+
self.out_channels = out_channels or in_channels
|
| 381 |
+
|
| 382 |
+
# 1. Positional, patch & conditional embeddings
|
| 383 |
+
self.rope_embedder = Lumina2RotaryPosEmbed(
|
| 384 |
+
theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size)
|
| 388 |
+
|
| 389 |
+
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
|
| 390 |
+
hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
# 2. Noise and context refinement blocks
|
| 394 |
+
self.noise_refiner = nn.ModuleList(
|
| 395 |
+
[
|
| 396 |
+
Lumina2TransformerBlock(
|
| 397 |
+
hidden_size,
|
| 398 |
+
num_attention_heads,
|
| 399 |
+
num_kv_heads,
|
| 400 |
+
multiple_of,
|
| 401 |
+
ffn_dim_multiplier,
|
| 402 |
+
norm_eps,
|
| 403 |
+
modulation=True,
|
| 404 |
+
)
|
| 405 |
+
for _ in range(num_refiner_layers)
|
| 406 |
+
]
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
self.context_refiner = nn.ModuleList(
|
| 410 |
+
[
|
| 411 |
+
Lumina2TransformerBlock(
|
| 412 |
+
hidden_size,
|
| 413 |
+
num_attention_heads,
|
| 414 |
+
num_kv_heads,
|
| 415 |
+
multiple_of,
|
| 416 |
+
ffn_dim_multiplier,
|
| 417 |
+
norm_eps,
|
| 418 |
+
modulation=False,
|
| 419 |
+
)
|
| 420 |
+
for _ in range(num_refiner_layers)
|
| 421 |
+
]
|
| 422 |
+
)
|
| 423 |
+
self.ori_inp_dit = "none"
|
| 424 |
+
self.ori_inp_refiner = None
|
| 425 |
+
|
| 426 |
+
# 3. Transformer blocks
|
| 427 |
+
self.layers = nn.ModuleList(
|
| 428 |
+
[
|
| 429 |
+
Lumina2TransformerBlock(
|
| 430 |
+
hidden_size,
|
| 431 |
+
num_attention_heads,
|
| 432 |
+
num_kv_heads,
|
| 433 |
+
multiple_of,
|
| 434 |
+
ffn_dim_multiplier,
|
| 435 |
+
norm_eps,
|
| 436 |
+
modulation=True,
|
| 437 |
+
)
|
| 438 |
+
for _ in range(num_layers)
|
| 439 |
+
]
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
# 4. Output norm & projection
|
| 443 |
+
self.norm_out = LuminaLayerNormContinuous(
|
| 444 |
+
embedding_dim=hidden_size,
|
| 445 |
+
conditioning_embedding_dim=min(hidden_size, 1024),
|
| 446 |
+
elementwise_affine=False,
|
| 447 |
+
eps=1e-6,
|
| 448 |
+
bias=True,
|
| 449 |
+
out_dim=patch_size * patch_size * self.out_channels,
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
self.gradient_checkpointing = False
|
| 453 |
+
|
| 454 |
+
def forward(
|
| 455 |
+
self,
|
| 456 |
+
hidden_states: torch.Tensor,
|
| 457 |
+
timestep: torch.Tensor,
|
| 458 |
+
encoder_hidden_states: torch.Tensor,
|
| 459 |
+
encoder_attention_mask: torch.Tensor,
|
| 460 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 461 |
+
return_dict: bool = True,
|
| 462 |
+
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
| 463 |
+
if attention_kwargs is not None:
|
| 464 |
+
attention_kwargs = attention_kwargs.copy()
|
| 465 |
+
lora_scale = attention_kwargs.pop("scale", 1.0)
|
| 466 |
+
else:
|
| 467 |
+
lora_scale = 1.0
|
| 468 |
+
|
| 469 |
+
if USE_PEFT_BACKEND:
|
| 470 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 471 |
+
scale_lora_layers(self, lora_scale)
|
| 472 |
+
else:
|
| 473 |
+
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
| 474 |
+
logger.warning(
|
| 475 |
+
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
# 1. Condition, positional & patch embedding
|
| 479 |
+
batch_size, _, height, width = hidden_states.shape
|
| 480 |
+
|
| 481 |
+
temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
|
| 482 |
+
|
| 483 |
+
(
|
| 484 |
+
hidden_states,
|
| 485 |
+
context_rotary_emb,
|
| 486 |
+
noise_rotary_emb,
|
| 487 |
+
rotary_emb,
|
| 488 |
+
encoder_seq_lengths,
|
| 489 |
+
seq_lengths,
|
| 490 |
+
) = self.rope_embedder(hidden_states, encoder_attention_mask)
|
| 491 |
+
|
| 492 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 493 |
+
|
| 494 |
+
# 2. Context & noise refinement
|
| 495 |
+
for layer in self.context_refiner:
|
| 496 |
+
encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)
|
| 497 |
+
|
| 498 |
+
for layer in self.noise_refiner:
|
| 499 |
+
hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)
|
| 500 |
+
|
| 501 |
+
if self.ori_inp_dit!="none" and self.ori_inp_refiner is not None:
|
| 502 |
+
single_img_length = hidden_states.shape[1]//2
|
| 503 |
+
initial_part = hidden_states[:, :single_img_length]
|
| 504 |
+
refined_part = self.ori_inp_refiner(hidden_states[:, single_img_length:])
|
| 505 |
+
updated_hidden_states = torch.cat((initial_part, refined_part), dim=1)
|
| 506 |
+
hidden_states = updated_hidden_states
|
| 507 |
+
|
| 508 |
+
# 3. Joint Transformer blocks
|
| 509 |
+
max_seq_len = max(seq_lengths)
|
| 510 |
+
use_mask = len(set(seq_lengths)) > 1
|
| 511 |
+
|
| 512 |
+
attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
|
| 513 |
+
joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
|
| 514 |
+
for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
|
| 515 |
+
attention_mask[i, :seq_len] = True
|
| 516 |
+
joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
|
| 517 |
+
joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]
|
| 518 |
+
|
| 519 |
+
hidden_states = joint_hidden_states
|
| 520 |
+
|
| 521 |
+
for layer in self.layers:
|
| 522 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 523 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 524 |
+
layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb
|
| 525 |
+
)
|
| 526 |
+
else:
|
| 527 |
+
hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb)
|
| 528 |
+
|
| 529 |
+
# 4. Output norm & projection
|
| 530 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 531 |
+
|
| 532 |
+
# 5. Unpatchify
|
| 533 |
+
p = self.config.patch_size
|
| 534 |
+
output = []
|
| 535 |
+
for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
|
| 536 |
+
output.append(
|
| 537 |
+
hidden_states[i][encoder_seq_len:seq_len]
|
| 538 |
+
.view(height // p, width // p, p, p, self.out_channels)
|
| 539 |
+
.permute(4, 0, 2, 1, 3)
|
| 540 |
+
.flatten(3, 4)
|
| 541 |
+
.flatten(1, 2)
|
| 542 |
+
)
|
| 543 |
+
output = torch.stack(output, dim=0)
|
| 544 |
+
|
| 545 |
+
if USE_PEFT_BACKEND:
|
| 546 |
+
# remove `lora_scale` from each PEFT layer
|
| 547 |
+
unscale_lora_layers(self, lora_scale)
|
| 548 |
+
|
| 549 |
+
if not return_dict:
|
| 550 |
+
return (output,)
|
| 551 |
+
return Transformer2DModelOutput(sample=output)
|
star/models/pixel_encoder/vq_model.py
ADDED
|
@@ -0,0 +1,510 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import List
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from deepspeed.utils import logger
|
| 9 |
+
|
| 10 |
+
@dataclass
|
| 11 |
+
class ModelArgs:
|
| 12 |
+
codebook_size: int = 16384
|
| 13 |
+
codebook_embed_dim: int = 8
|
| 14 |
+
codebook_l2_norm: bool = False
|
| 15 |
+
codebook_show_usage: bool = True
|
| 16 |
+
commit_loss_beta: float = 0.25
|
| 17 |
+
entropy_loss_ratio: float = 0.0
|
| 18 |
+
encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
|
| 19 |
+
decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
|
| 20 |
+
z_channels: int = 256
|
| 21 |
+
dropout_p: float = 0.0
|
| 22 |
+
num_res_blocks: int = 2
|
| 23 |
+
ch: int=128
|
| 24 |
+
attn_num_heads: int = 1
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class VQModel(nn.Module):
|
| 28 |
+
def __init__(self, config: ModelArgs):
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.config = config
|
| 31 |
+
self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p, num_res_blocks=config.num_res_blocks, ch=config.ch, attn_num_heads=config.attn_num_heads)
|
| 32 |
+
self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p, num_res_blocks=config.num_res_blocks, ch=config.ch, attn_num_heads=config.attn_num_heads)
|
| 33 |
+
|
| 34 |
+
self.quantize = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
|
| 35 |
+
config.commit_loss_beta, config.entropy_loss_ratio,
|
| 36 |
+
config.codebook_l2_norm, config.codebook_show_usage)
|
| 37 |
+
self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
|
| 38 |
+
self.post_quant_conv = nn.Conv2d(config.codebook_embed_dim, config.z_channels, 1)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def encode(self, x):
|
| 42 |
+
h = self.encoder(x)
|
| 43 |
+
h = self.quant_conv(h)
|
| 44 |
+
|
| 45 |
+
quant, emb_loss, info = self.quantize(h)
|
| 46 |
+
return quant, emb_loss, info
|
| 47 |
+
|
| 48 |
+
def decode(self, quant):
|
| 49 |
+
quant = self.post_quant_conv(quant)
|
| 50 |
+
dec = self.decoder(quant)
|
| 51 |
+
return dec
|
| 52 |
+
|
| 53 |
+
def decode_code(self, code_b, shape=None, channel_first=True):
|
| 54 |
+
quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
|
| 55 |
+
dec = self.decode(quant_b)
|
| 56 |
+
return dec # [B, C, H, W]
|
| 57 |
+
|
| 58 |
+
def forward(self, input):
|
| 59 |
+
quant, diff, _ = self.encode(input)
|
| 60 |
+
dec = self.decode(quant)
|
| 61 |
+
return dec, diff
|
| 62 |
+
|
| 63 |
+
def get_codebook_entry(self, code_b, shape=None, channel_first=True):
|
| 64 |
+
quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
|
| 65 |
+
return quant_b
|
| 66 |
+
|
| 67 |
+
def image_to_seq(self, image):
|
| 68 |
+
quant, _, [_, _, indices] = self.encode(image)
|
| 69 |
+
batch_size = image.shape[0]
|
| 70 |
+
return indices.reshape(batch_size, -1)
|
| 71 |
+
|
| 72 |
+
def seq_to_image(self, tokens):
|
| 73 |
+
tokens = torch.clamp(tokens, min=0)
|
| 74 |
+
assert tokens.size(-1) == self.config.num_tokens, (
|
| 75 |
+
f"can not generate the image as the token length is {tokens.size(-1)} != {self.config.num_tokens}"
|
| 76 |
+
)
|
| 77 |
+
bs, HW = tokens.shape
|
| 78 |
+
H = W = int(math.sqrt(HW))
|
| 79 |
+
images = self.decode_code(tokens, shape=[bs, self.config.codebook_embed_dim, H, W])
|
| 80 |
+
images = torch.clip((images+1)/2, 0, 1)
|
| 81 |
+
images = torch.permute(images, [0, 2, 3, 1])
|
| 82 |
+
|
| 83 |
+
return images
|
| 84 |
+
|
| 85 |
+
def load_trained_weights(self, pretrained=None):
|
| 86 |
+
device_index = torch.cuda.current_device()
|
| 87 |
+
device = torch.device(f'cuda:{device_index}')
|
| 88 |
+
weights = torch.load(pretrained, map_location=device)
|
| 89 |
+
self.load_state_dict(weights, strict=True)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class Encoder(nn.Module):
|
| 93 |
+
def __init__(self, in_channels=3, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2,
|
| 94 |
+
norm_type='group', dropout=0.0, resamp_with_conv=True, z_channels=256, attn_num_heads=1):
|
| 95 |
+
super().__init__()
|
| 96 |
+
self.num_resolutions = len(ch_mult)
|
| 97 |
+
self.num_res_blocks = num_res_blocks
|
| 98 |
+
self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
|
| 99 |
+
|
| 100 |
+
# downsampling
|
| 101 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
| 102 |
+
self.conv_blocks = nn.ModuleList()
|
| 103 |
+
for i_level in range(self.num_resolutions):
|
| 104 |
+
conv_block = nn.Module()
|
| 105 |
+
# res & attn
|
| 106 |
+
res_block = nn.ModuleList()
|
| 107 |
+
attn_block = nn.ModuleList()
|
| 108 |
+
block_in = ch*in_ch_mult[i_level]
|
| 109 |
+
block_out = ch*ch_mult[i_level]
|
| 110 |
+
for _ in range(self.num_res_blocks):
|
| 111 |
+
res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
|
| 112 |
+
block_in = block_out
|
| 113 |
+
if i_level == self.num_resolutions - 1:
|
| 114 |
+
attn_block.append(AttnBlock(block_in, norm_type, attn_num_heads))
|
| 115 |
+
conv_block.res = res_block
|
| 116 |
+
conv_block.attn = attn_block
|
| 117 |
+
# downsample
|
| 118 |
+
if i_level != self.num_resolutions-1:
|
| 119 |
+
conv_block.downsample = Downsample(block_in, resamp_with_conv)
|
| 120 |
+
self.conv_blocks.append(conv_block)
|
| 121 |
+
|
| 122 |
+
# middle
|
| 123 |
+
self.mid = nn.ModuleList()
|
| 124 |
+
self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
|
| 125 |
+
self.mid.append(AttnBlock(block_in, norm_type=norm_type, num_heads=attn_num_heads))
|
| 126 |
+
self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
|
| 127 |
+
|
| 128 |
+
# end
|
| 129 |
+
self.norm_out = Normalize(block_in, norm_type)
|
| 130 |
+
self.conv_out = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def forward(self, x):
|
| 134 |
+
h = self.conv_in(x)
|
| 135 |
+
# downsampling
|
| 136 |
+
for i_level, block in enumerate(self.conv_blocks):
|
| 137 |
+
for i_block in range(self.num_res_blocks):
|
| 138 |
+
h = block.res[i_block](h)
|
| 139 |
+
if len(block.attn) > 0:
|
| 140 |
+
h = block.attn[i_block](h)
|
| 141 |
+
if i_level != self.num_resolutions - 1:
|
| 142 |
+
h = block.downsample(h)
|
| 143 |
+
|
| 144 |
+
# middle
|
| 145 |
+
for mid_block in self.mid:
|
| 146 |
+
h = mid_block(h)
|
| 147 |
+
|
| 148 |
+
# end
|
| 149 |
+
h = self.norm_out(h)
|
| 150 |
+
h = nonlinearity(h)
|
| 151 |
+
h = self.conv_out(h)
|
| 152 |
+
return h
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class Decoder(nn.Module):
|
| 157 |
+
def __init__(self, z_channels=256, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, norm_type="group",
|
| 158 |
+
dropout=0.0, resamp_with_conv=True, out_channels=3, attn_num_heads=1):
|
| 159 |
+
super().__init__()
|
| 160 |
+
self.num_resolutions = len(ch_mult)
|
| 161 |
+
self.num_res_blocks = num_res_blocks
|
| 162 |
+
|
| 163 |
+
block_in = ch*ch_mult[self.num_resolutions-1]
|
| 164 |
+
# z to block_in
|
| 165 |
+
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
| 166 |
+
|
| 167 |
+
# middle
|
| 168 |
+
self.mid = nn.ModuleList()
|
| 169 |
+
self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
|
| 170 |
+
self.mid.append(AttnBlock(block_in, norm_type=norm_type, num_heads=attn_num_heads))
|
| 171 |
+
self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
|
| 172 |
+
|
| 173 |
+
# upsampling
|
| 174 |
+
self.conv_blocks = nn.ModuleList()
|
| 175 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 176 |
+
conv_block = nn.Module()
|
| 177 |
+
# res & attn
|
| 178 |
+
res_block = nn.ModuleList()
|
| 179 |
+
attn_block = nn.ModuleList()
|
| 180 |
+
block_out = ch*ch_mult[i_level]
|
| 181 |
+
for _ in range(self.num_res_blocks + 1):
|
| 182 |
+
res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
|
| 183 |
+
block_in = block_out
|
| 184 |
+
if i_level == self.num_resolutions - 1:
|
| 185 |
+
attn_block.append(AttnBlock(block_in, norm_type, attn_num_heads))
|
| 186 |
+
conv_block.res = res_block
|
| 187 |
+
conv_block.attn = attn_block
|
| 188 |
+
# downsample
|
| 189 |
+
if i_level != 0:
|
| 190 |
+
conv_block.upsample = Upsample(block_in, resamp_with_conv)
|
| 191 |
+
self.conv_blocks.append(conv_block)
|
| 192 |
+
|
| 193 |
+
# end
|
| 194 |
+
self.norm_out = Normalize(block_in, norm_type)
|
| 195 |
+
self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
|
| 196 |
+
|
| 197 |
+
@property
|
| 198 |
+
def last_layer(self):
|
| 199 |
+
return self.conv_out.weight
|
| 200 |
+
|
| 201 |
+
def forward(self, z):
|
| 202 |
+
# z to block_in
|
| 203 |
+
h = self.conv_in(z)
|
| 204 |
+
|
| 205 |
+
# middle
|
| 206 |
+
for mid_block in self.mid:
|
| 207 |
+
h = mid_block(h)
|
| 208 |
+
|
| 209 |
+
# upsampling
|
| 210 |
+
for i_level, block in enumerate(self.conv_blocks):
|
| 211 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 212 |
+
h = block.res[i_block](h)
|
| 213 |
+
if len(block.attn) > 0:
|
| 214 |
+
h = block.attn[i_block](h)
|
| 215 |
+
if i_level != self.num_resolutions - 1:
|
| 216 |
+
h = block.upsample(h)
|
| 217 |
+
|
| 218 |
+
# end
|
| 219 |
+
h = self.norm_out(h)
|
| 220 |
+
h = nonlinearity(h)
|
| 221 |
+
h = self.conv_out(h)
|
| 222 |
+
return h
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class VectorQuantizer(nn.Module):
|
| 226 |
+
def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage=False):
|
| 227 |
+
super().__init__()
|
| 228 |
+
self.n_e = n_e
|
| 229 |
+
self.e_dim = e_dim
|
| 230 |
+
self.beta = beta
|
| 231 |
+
self.entropy_loss_ratio = entropy_loss_ratio
|
| 232 |
+
self.l2_norm = l2_norm
|
| 233 |
+
self.show_usage = show_usage
|
| 234 |
+
|
| 235 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
| 236 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
| 237 |
+
|
| 238 |
+
if self.l2_norm:
|
| 239 |
+
self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
|
| 240 |
+
if self.show_usage:
|
| 241 |
+
if self.n_e < 65536:
|
| 242 |
+
self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
|
| 243 |
+
else:
|
| 244 |
+
self.register_buffer("codebook_used", nn.Parameter(torch.zeros(self.n_e+1)))
|
| 245 |
+
# self.register_buffer("codebook_used", nn.Parameter(torch.zeros(196608)))
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# self.h_, self.w_ = int(self.n_e ** 0.5), int(self.n_e ** 0.5)
|
| 249 |
+
if int(self.n_e ** 0.5) ** 2 == self.n_e:
|
| 250 |
+
self.h_, self.w_ = int(self.n_e ** 0.5), int(self.n_e ** 0.5)
|
| 251 |
+
else:
|
| 252 |
+
self.h_ = int((self.n_e * 2) ** 0.5)
|
| 253 |
+
self.w_ = self.n_e // self.h_
|
| 254 |
+
|
| 255 |
+
def forward(self, z):
|
| 256 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
| 257 |
+
z = torch.einsum('b c h w -> b h w c', z).contiguous()
|
| 258 |
+
z_flattened = z.view(z.shape[0], -1, self.e_dim) # [b, h*w, e_dim]
|
| 259 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
| 260 |
+
|
| 261 |
+
emb_weights = self.embedding.weight[None].repeat(z.shape[0], 1, 1)
|
| 262 |
+
|
| 263 |
+
if self.l2_norm:
|
| 264 |
+
z = F.normalize(z, p=2, dim=-1)
|
| 265 |
+
z_flattened = F.normalize(z_flattened, p=2, dim=-1)
|
| 266 |
+
embedding = F.normalize(emb_weights, p=2, dim=-1) # [b, n_e, e_dim]
|
| 267 |
+
else:
|
| 268 |
+
embedding = emb_weights
|
| 269 |
+
|
| 270 |
+
d = torch.sum(z_flattened ** 2, dim=2, keepdim=True) + \
|
| 271 |
+
torch.sum(embedding**2, dim=2).unsqueeze(1) - 2 * \
|
| 272 |
+
torch.einsum('bld,bnd->bln', z_flattened, embedding) # [n, h*w, n_e]
|
| 273 |
+
|
| 274 |
+
min_encoding_indices = torch.argmin(d, dim=2) # [n, h*w]
|
| 275 |
+
z_q = torch.stack([embedding[b, min_encoding_indices[b]] for b in range(z.shape[0])]) # [n, h*w, e_dim]
|
| 276 |
+
z_q = z_q.view(z.shape)
|
| 277 |
+
perplexity = None
|
| 278 |
+
min_encodings = None
|
| 279 |
+
vq_loss = None
|
| 280 |
+
commit_loss = None
|
| 281 |
+
entropy_loss = None
|
| 282 |
+
codebook_usage = 0
|
| 283 |
+
|
| 284 |
+
if self.show_usage and self.training:
|
| 285 |
+
self.codebook_used = self.codebook_used.long()
|
| 286 |
+
cur_len = min_encoding_indices.shape.numel()
|
| 287 |
+
self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()
|
| 288 |
+
self.codebook_used[-cur_len:] = min_encoding_indices.view(-1)
|
| 289 |
+
codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# compute loss for embedding
|
| 293 |
+
if self.training:
|
| 294 |
+
vq_loss = torch.mean((z_q - z.detach()) ** 2)
|
| 295 |
+
commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
|
| 296 |
+
entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d.view(-1, d.shape[-1]))
|
| 297 |
+
|
| 298 |
+
# preserve gradients
|
| 299 |
+
z_q = z + (z_q - z).detach()
|
| 300 |
+
|
| 301 |
+
# reshape back to match original input shape
|
| 302 |
+
z_q = torch.einsum('b h w c -> b c h w', z_q)
|
| 303 |
+
|
| 304 |
+
return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)
|
| 305 |
+
|
| 306 |
+
def get_codebook_entry(self, indices, shape=None, channel_first=True):
|
| 307 |
+
|
| 308 |
+
if self.l2_norm:
|
| 309 |
+
embedding = F.normalize(self.embedding.weight, p=2, dim=-1) # [n, n_e, e_dim]
|
| 310 |
+
else:
|
| 311 |
+
embedding = self.embedding.weight
|
| 312 |
+
|
| 313 |
+
z_q = embedding[indices]
|
| 314 |
+
|
| 315 |
+
if shape is not None:
|
| 316 |
+
if channel_first:
|
| 317 |
+
z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1]) # [B, H, W, D]
|
| 318 |
+
# reshape back to match original input shape
|
| 319 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous() # [B, D, H, W]
|
| 320 |
+
else:
|
| 321 |
+
z_q = z_q.view(shape)
|
| 322 |
+
return z_q
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class ResnetBlock(nn.Module):
|
| 326 |
+
def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group'):
|
| 327 |
+
super().__init__()
|
| 328 |
+
self.in_channels = in_channels
|
| 329 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 330 |
+
self.out_channels = out_channels
|
| 331 |
+
self.use_conv_shortcut = conv_shortcut
|
| 332 |
+
|
| 333 |
+
self.norm1 = Normalize(in_channels, norm_type)
|
| 334 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 335 |
+
self.norm2 = Normalize(out_channels, norm_type)
|
| 336 |
+
self.dropout = nn.Dropout(dropout)
|
| 337 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 338 |
+
|
| 339 |
+
if self.in_channels != self.out_channels:
|
| 340 |
+
if self.use_conv_shortcut:
|
| 341 |
+
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 342 |
+
else:
|
| 343 |
+
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
| 344 |
+
|
| 345 |
+
def forward(self, x):
|
| 346 |
+
h = x
|
| 347 |
+
h = self.norm1(h)
|
| 348 |
+
h = nonlinearity(h)
|
| 349 |
+
h = self.conv1(h)
|
| 350 |
+
h = self.norm2(h)
|
| 351 |
+
h = nonlinearity(h)
|
| 352 |
+
h = self.dropout(h)
|
| 353 |
+
h = self.conv2(h)
|
| 354 |
+
if self.in_channels != self.out_channels:
|
| 355 |
+
if self.use_conv_shortcut:
|
| 356 |
+
x = self.conv_shortcut(x)
|
| 357 |
+
else:
|
| 358 |
+
x = self.nin_shortcut(x)
|
| 359 |
+
return x+h
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class AttnBlock(nn.Module):
|
| 364 |
+
def __init__(self, in_channels, norm_type='group', num_heads=1):
|
| 365 |
+
super().__init__()
|
| 366 |
+
self.num_heads = num_heads
|
| 367 |
+
assert in_channels % self.num_heads == 0
|
| 368 |
+
|
| 369 |
+
self.norm = Normalize(in_channels, norm_type)
|
| 370 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 371 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 372 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 373 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def forward_single_head(self, x):
|
| 377 |
+
h_ = x
|
| 378 |
+
h_ = self.norm(h_)
|
| 379 |
+
q = self.q(h_)
|
| 380 |
+
k = self.k(h_)
|
| 381 |
+
v = self.v(h_)
|
| 382 |
+
|
| 383 |
+
# compute attention
|
| 384 |
+
b,c,h,w = q.shape
|
| 385 |
+
q = q.reshape(b,c,h*w)
|
| 386 |
+
q = q.permute(0,2,1) # b,hw,c
|
| 387 |
+
k = k.reshape(b,c,h*w) # b,c,hw
|
| 388 |
+
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
| 389 |
+
w_ = w_ * (int(c)**(-0.5))
|
| 390 |
+
w_ = F.softmax(w_, dim=2)
|
| 391 |
+
|
| 392 |
+
# attend to values
|
| 393 |
+
v = v.reshape(b,c,h*w)
|
| 394 |
+
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
| 395 |
+
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
| 396 |
+
h_ = h_.reshape(b,c,h,w)
|
| 397 |
+
|
| 398 |
+
h_ = self.proj_out(h_)
|
| 399 |
+
|
| 400 |
+
return x+h_
|
| 401 |
+
|
| 402 |
+
def forwar_multi_head(self, x):
|
| 403 |
+
h_ = x
|
| 404 |
+
h_ = self.norm(h_)
|
| 405 |
+
q = self.q(h_)
|
| 406 |
+
k = self.k(h_)
|
| 407 |
+
v = self.v(h_)
|
| 408 |
+
|
| 409 |
+
# compute attention
|
| 410 |
+
b, c, h, w = q.shape
|
| 411 |
+
q = q.reshape(b, self.num_heads, c//self.num_heads, h * w) # b, head, c, hw
|
| 412 |
+
q = q.permute(0, 1, 3, 2) # b, head, hw, c
|
| 413 |
+
k = k.reshape(b, self.num_heads, c//self.num_heads, h * w) # b, head, c, hw
|
| 414 |
+
|
| 415 |
+
# w_ = torch.bmm(q,k) # b,head,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
| 416 |
+
w_ = q @ k # b,head,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
| 417 |
+
w_ = w_ * (int(c // self.num_heads) ** (-0.5))
|
| 418 |
+
w_ = torch.nn.functional.softmax(w_, dim=3)
|
| 419 |
+
|
| 420 |
+
# attend to values
|
| 421 |
+
v = v.reshape(b, self.num_heads, c//self.num_heads, h * w) # b, head, c, hw
|
| 422 |
+
|
| 423 |
+
w_ = w_.permute(0, 1, 3, 2) # b,head,hw,hw (first hw of k, second of q)
|
| 424 |
+
h_ = v @ w_ # b, head,c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
| 425 |
+
h_ = h_.reshape(b, c, h, w)
|
| 426 |
+
|
| 427 |
+
h_ = self.proj_out(h_)
|
| 428 |
+
|
| 429 |
+
return x + h_
|
| 430 |
+
|
| 431 |
+
def forward(self, x):
|
| 432 |
+
if self.num_heads > 1:
|
| 433 |
+
return self.forwar_multi_head(x)
|
| 434 |
+
else:
|
| 435 |
+
return self.forward_single_head(x)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def nonlinearity(x):
|
| 439 |
+
# swish
|
| 440 |
+
return x*torch.sigmoid(x)
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def Normalize(in_channels, norm_type='group'):
|
| 444 |
+
assert norm_type in ['group', 'batch']
|
| 445 |
+
if norm_type == 'group':
|
| 446 |
+
return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 447 |
+
elif norm_type == 'batch':
|
| 448 |
+
return nn.SyncBatchNorm(in_channels)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
class Upsample(nn.Module):
|
| 452 |
+
def __init__(self, in_channels, with_conv):
|
| 453 |
+
super().__init__()
|
| 454 |
+
self.with_conv = with_conv
|
| 455 |
+
if self.with_conv:
|
| 456 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
| 457 |
+
|
| 458 |
+
def forward(self, x):
|
| 459 |
+
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 460 |
+
if self.with_conv:
|
| 461 |
+
x = self.conv(x)
|
| 462 |
+
return x
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
class Downsample(nn.Module):
|
| 466 |
+
def __init__(self, in_channels, with_conv):
|
| 467 |
+
super().__init__()
|
| 468 |
+
self.with_conv = with_conv
|
| 469 |
+
if self.with_conv:
|
| 470 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 471 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
| 472 |
+
|
| 473 |
+
def forward(self, x):
|
| 474 |
+
if self.with_conv:
|
| 475 |
+
pad = (0,1,0,1)
|
| 476 |
+
x = F.pad(x, pad, mode="constant", value=0)
|
| 477 |
+
x = self.conv(x)
|
| 478 |
+
else:
|
| 479 |
+
x = F.avg_pool2d(x, kernel_size=2, stride=2)
|
| 480 |
+
return x
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
|
| 484 |
+
flat_affinity = affinity.reshape(-1, affinity.shape[-1])
|
| 485 |
+
flat_affinity /= temperature
|
| 486 |
+
probs = F.softmax(flat_affinity, dim=-1)
|
| 487 |
+
log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
|
| 488 |
+
if loss_type == "softmax":
|
| 489 |
+
target_probs = probs
|
| 490 |
+
else:
|
| 491 |
+
raise ValueError("Entropy loss {} not supported".format(loss_type))
|
| 492 |
+
avg_probs = torch.mean(target_probs, dim=0)
|
| 493 |
+
avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
|
| 494 |
+
sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
|
| 495 |
+
loss = sample_entropy - avg_entropy
|
| 496 |
+
return loss
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
#################################################################################
|
| 500 |
+
# VQ Model Configs #
|
| 501 |
+
#################################################################################
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def VQ_Model(config, **kwargs):
|
| 505 |
+
model = VQModel(ModelArgs(encoder_ch_mult=[1, 2, 2, 4, 8], decoder_ch_mult=[1, 2, 2, 4, 8], codebook_size=config.image_token_size, codebook_embed_dim=config.n_embed, z_channels=512, ch=256, attn_num_heads=config.num_heads, **kwargs))
|
| 506 |
+
|
| 507 |
+
pretrained = config.model_path
|
| 508 |
+
if pretrained:
|
| 509 |
+
model.load_trained_weights(pretrained)
|
| 510 |
+
return model
|
star/models/rope_2d.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import copy
|
| 3 |
+
import json
|
| 4 |
+
import random
|
| 5 |
+
import logging
|
| 6 |
+
import re
|
| 7 |
+
import time
|
| 8 |
+
import math
|
| 9 |
+
import ast
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
from typing import Dict, Optional, Sequence, List, Tuple
|
| 12 |
+
from io import BytesIO
|
| 13 |
+
import base64
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
from torch.utils.data import Dataset
|
| 18 |
+
from PIL import Image
|
| 19 |
+
from decord import VideoReader
|
| 20 |
+
import transformers
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_rope_index_25(
|
| 24 |
+
spatial_merge_size: Optional[int] = 2,
|
| 25 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 26 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
| 27 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
| 28 |
+
second_per_grid_ts: Optional[torch.Tensor] = None,
|
| 29 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 30 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 31 |
+
"""
|
| 32 |
+
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
|
| 33 |
+
|
| 34 |
+
Explanation:
|
| 35 |
+
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
|
| 36 |
+
|
| 37 |
+
For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
|
| 38 |
+
Examples:
|
| 39 |
+
input_ids: [T T T T T], here T is for text.
|
| 40 |
+
temporal position_ids: [0, 1, 2, 3, 4]
|
| 41 |
+
height position_ids: [0, 1, 2, 3, 4]
|
| 42 |
+
width position_ids: [0, 1, 2, 3, 4]
|
| 43 |
+
|
| 44 |
+
For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
|
| 45 |
+
and 1D rotary position embedding for text part.
|
| 46 |
+
Examples:
|
| 47 |
+
Temporal (Time): 3 patches, representing different segments of the video in time.
|
| 48 |
+
Height: 2 patches, dividing each frame vertically.
|
| 49 |
+
Width: 2 patches, dividing each frame horizontally.
|
| 50 |
+
We also have some important parameters:
|
| 51 |
+
fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
|
| 52 |
+
tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
|
| 53 |
+
temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
|
| 54 |
+
interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs.
|
| 55 |
+
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
|
| 56 |
+
vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
|
| 57 |
+
vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
|
| 58 |
+
vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
|
| 59 |
+
text temporal position_ids: [101, 102, 103, 104, 105]
|
| 60 |
+
text height position_ids: [101, 102, 103, 104, 105]
|
| 61 |
+
text width position_ids: [101, 102, 103, 104, 105]
|
| 62 |
+
Here we calculate the text start position_ids as the max vision position_ids plus 1.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 66 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 67 |
+
it.
|
| 68 |
+
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
| 69 |
+
The temporal, height and width of feature shape of each image in LLM.
|
| 70 |
+
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
| 71 |
+
The temporal, height and width of feature shape of each video in LLM.
|
| 72 |
+
second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
|
| 73 |
+
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
|
| 74 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 75 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 76 |
+
|
| 77 |
+
- 1 for tokens that are **not masked**,
|
| 78 |
+
- 0 for tokens that are **masked**.
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
|
| 82 |
+
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
|
| 83 |
+
"""
|
| 84 |
+
image_token_id = 151655
|
| 85 |
+
video_token_id = 151656
|
| 86 |
+
vision_start_token_id = 151652
|
| 87 |
+
mrope_position_deltas = []
|
| 88 |
+
if input_ids is not None and (
|
| 89 |
+
image_grid_thw is not None or video_grid_thw is not None
|
| 90 |
+
):
|
| 91 |
+
total_input_ids = input_ids
|
| 92 |
+
if attention_mask is None:
|
| 93 |
+
attention_mask = torch.ones_like(total_input_ids)
|
| 94 |
+
position_ids = torch.ones(
|
| 95 |
+
3,
|
| 96 |
+
input_ids.shape[0],
|
| 97 |
+
input_ids.shape[1],
|
| 98 |
+
dtype=input_ids.dtype,
|
| 99 |
+
device=input_ids.device,
|
| 100 |
+
)
|
| 101 |
+
image_index, video_index = 0, 0
|
| 102 |
+
attention_mask = attention_mask.to(total_input_ids.device)
|
| 103 |
+
for i, input_ids in enumerate(total_input_ids):
|
| 104 |
+
input_ids = input_ids[attention_mask[i] == 1]
|
| 105 |
+
image_nums, video_nums = 0, 0
|
| 106 |
+
vision_start_indices = torch.argwhere(
|
| 107 |
+
input_ids == vision_start_token_id
|
| 108 |
+
).squeeze(1)
|
| 109 |
+
vision_tokens = input_ids[vision_start_indices + 1]
|
| 110 |
+
image_nums = (vision_tokens == image_token_id).sum()
|
| 111 |
+
video_nums = (vision_tokens == video_token_id).sum()
|
| 112 |
+
input_tokens = input_ids.tolist()
|
| 113 |
+
llm_pos_ids_list: list = []
|
| 114 |
+
st = 0
|
| 115 |
+
remain_images, remain_videos = image_nums, video_nums
|
| 116 |
+
for _ in range(image_nums + video_nums):
|
| 117 |
+
if image_token_id in input_tokens and remain_images > 0:
|
| 118 |
+
ed_image = input_tokens.index(image_token_id, st)
|
| 119 |
+
else:
|
| 120 |
+
ed_image = len(input_tokens) + 1
|
| 121 |
+
if video_token_id in input_tokens and remain_videos > 0:
|
| 122 |
+
ed_video = input_tokens.index(video_token_id, st)
|
| 123 |
+
else:
|
| 124 |
+
ed_video = len(input_tokens) + 1
|
| 125 |
+
if ed_image < ed_video:
|
| 126 |
+
t, h, w = (
|
| 127 |
+
image_grid_thw[image_index][0],
|
| 128 |
+
image_grid_thw[image_index][1],
|
| 129 |
+
image_grid_thw[image_index][2],
|
| 130 |
+
)
|
| 131 |
+
second_per_grid_t = 0
|
| 132 |
+
image_index += 1
|
| 133 |
+
remain_images -= 1
|
| 134 |
+
ed = ed_image
|
| 135 |
+
|
| 136 |
+
else:
|
| 137 |
+
t, h, w = (
|
| 138 |
+
video_grid_thw[video_index][0],
|
| 139 |
+
video_grid_thw[video_index][1],
|
| 140 |
+
video_grid_thw[video_index][2],
|
| 141 |
+
)
|
| 142 |
+
if second_per_grid_ts is not None:
|
| 143 |
+
second_per_grid_t = second_per_grid_ts[video_index]
|
| 144 |
+
else:
|
| 145 |
+
second_per_grid_t = 1.0
|
| 146 |
+
video_index += 1
|
| 147 |
+
remain_videos -= 1
|
| 148 |
+
ed = ed_video
|
| 149 |
+
llm_grid_t, llm_grid_h, llm_grid_w = (
|
| 150 |
+
t.item(),
|
| 151 |
+
h.item() // spatial_merge_size,
|
| 152 |
+
w.item() // spatial_merge_size,
|
| 153 |
+
)
|
| 154 |
+
text_len = ed - st
|
| 155 |
+
|
| 156 |
+
st_idx = (
|
| 157 |
+
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
| 158 |
+
)
|
| 159 |
+
llm_pos_ids_list.append(
|
| 160 |
+
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
|
| 164 |
+
expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
|
| 165 |
+
|
| 166 |
+
time_tensor = expanded_range * second_per_grid_t * 2
|
| 167 |
+
|
| 168 |
+
time_tensor_long = time_tensor.long()
|
| 169 |
+
t_index = time_tensor_long.flatten()
|
| 170 |
+
|
| 171 |
+
h_index = (
|
| 172 |
+
torch.arange(llm_grid_h)
|
| 173 |
+
.view(1, -1, 1)
|
| 174 |
+
.expand(llm_grid_t, -1, llm_grid_w)
|
| 175 |
+
.flatten()
|
| 176 |
+
)
|
| 177 |
+
w_index = (
|
| 178 |
+
torch.arange(llm_grid_w)
|
| 179 |
+
.view(1, 1, -1)
|
| 180 |
+
.expand(llm_grid_t, llm_grid_h, -1)
|
| 181 |
+
.flatten()
|
| 182 |
+
)
|
| 183 |
+
llm_pos_ids_list.append(
|
| 184 |
+
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
|
| 185 |
+
)
|
| 186 |
+
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
| 187 |
+
|
| 188 |
+
if st < len(input_tokens):
|
| 189 |
+
st_idx = (
|
| 190 |
+
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
| 191 |
+
)
|
| 192 |
+
text_len = len(input_tokens) - st
|
| 193 |
+
llm_pos_ids_list.append(
|
| 194 |
+
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
| 198 |
+
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
|
| 199 |
+
position_ids.device
|
| 200 |
+
)
|
| 201 |
+
mrope_position_deltas.append(
|
| 202 |
+
llm_positions.max() + 1 - len(total_input_ids[i])
|
| 203 |
+
)
|
| 204 |
+
mrope_position_deltas = torch.tensor(
|
| 205 |
+
mrope_position_deltas, device=input_ids.device
|
| 206 |
+
).unsqueeze(1)
|
| 207 |
+
return position_ids, mrope_position_deltas
|
| 208 |
+
else:
|
| 209 |
+
if attention_mask is not None:
|
| 210 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 211 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 212 |
+
position_ids = (
|
| 213 |
+
position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
|
| 214 |
+
)
|
| 215 |
+
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
|
| 216 |
+
-1, keepdim=True
|
| 217 |
+
)[0]
|
| 218 |
+
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
|
| 219 |
+
else:
|
| 220 |
+
position_ids = (
|
| 221 |
+
torch.arange(input_ids.shape[1], device=input_ids.device)
|
| 222 |
+
.view(1, 1, -1)
|
| 223 |
+
.expand(3, input_ids.shape[0], -1)
|
| 224 |
+
)
|
| 225 |
+
mrope_position_deltas = torch.zeros(
|
| 226 |
+
[input_ids.shape[0], 1],
|
| 227 |
+
device=input_ids.device,
|
| 228 |
+
dtype=input_ids.dtype,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
return position_ids, mrope_position_deltas
|
| 232 |
+
|