# star/models/model.py
import os
import math
import torch
import requests
from io import BytesIO
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from torch.nn import CrossEntropyLoss
from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForCausalLM,
PreTrainedModel
)
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, Qwen2VLProcessor
from star.models.config import STARMultiModalConfig
from star.models.pixel_encoder.vq_model import VQ_Model
from star.models.adapter.projector import MlpProjector
from star.models.pixel_decoder.lumina2_decoder import Lumina2Decoder
from star.models.data_process_utils import get_full_transform, get_vq_transform, preprocess_image_gen
from star.models.rope_2d import get_rope_index_25
class STARMultiModal(PreTrainedModel):
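    """Unified multimodal model built on a Qwen2.5-VL language backbone, with a VQ
    pixel encoder, an MLP pixel adapter, a stacked autoregressive head, a linear
    pixel-output head over the VQ codebook, and an optional Lumina2 diffusion
    decoder. Supports text-to-image generation, reasoning-augmented generation,
    image editing, and image understanding."""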
def __init__(self, config: STARMultiModalConfig, args=None, **kwargs):
super().__init__(config)
self.config = config
self.args = args if args is not None else kwargs.get("args", None)
# Pixel Encoder Generation
model_name = config.pixel_encoder.model_name
if model_name == "VQ_Model":
self.pixel_encoder = VQ_Model(config.pixel_encoder)
else:
assert None, f"Unsupported {model_name}"
self.pixel_encoder.eval()
# Pixel Adapter Generation
model_name = config.pixel_adapter.model_name
if model_name == "MLP_GELU":
self.pixel_adapter = MlpProjector(config.pixel_adapter)
else:
assert None, f"Unsupported {model_name}"
        # Pixel Output Head Generation
self.pixel_output_head = torch.nn.Linear(config.pixel_output_head.n_embed, config.pixel_output_head.image_token_size)
if getattr(args, "diffusion_as_decoder") and args.diffusion_as_decoder:
self.diffusion_decoder = Lumina2Decoder(config.pixel_decoder, args)
else:
self.diffusion_decoder = None
# Large Language Model
model_name, model_path = config.language_model.model_name, config.language_model.model_path
if model_name == "Qwen2.5-VL":
self.llm = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="cuda")
self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
self.tokenizer = self.processor.tokenizer
self.image_processor = self.processor.image_processor
self.image_processor.max_pixels = self.args.max_pixels
self.image_processor.min_pixels = self.args.min_pixels
self.image_processor.size["longest_edge"] = self.args.max_pixels
self.image_processor.size["shortest_edge"] = self.args.min_pixels
special_token_tags = ["<|vision_start|>", "<|vision_pad|>", "<|image_pad|>", "<|vision_end|>", "<|fim_pad|>"]
self.special_tokens = {tag: self.tokenizer.vocab.get(tag, None) for tag in special_token_tags}
else:
assert None, f"unsupported {model_name}: {model_path}"
self.llm.generation_config.pad_token_id = self.tokenizer.encode(self.tokenizer.pad_token)[0]
if self.args.grad_ckpt:
self.llm.gradient_checkpointing_enable()
self.llm.visual.gradient_checkpointing_enable()
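        # Stacked AR head: a shallow Qwen2.5-VL copy whose transformer keeps only the
        # last `config.stacked_ar.num_layers` layers of the pretrained backbone; its
        # vision tower, input embeddings, and LM head are removed since it only refines
        # hidden states produced by the main LLM.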
stacked_ar_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
num_layers_to_extract = config.stacked_ar.num_layers
stacked_ar_config.num_hidden_layers = num_layers_to_extract
self.stacked_ar = Qwen2_5_VLForConditionalGeneration(stacked_ar_config)
temp_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
total_layers = len(temp_model.model.layers)
start_layer = max(0, total_layers - num_layers_to_extract)
temp_model.model.layers = temp_model.model.layers[start_layer:]
self.stacked_ar.load_state_dict(temp_model.state_dict(), strict=False)
self.stacked_ar = self.stacked_ar.to("cuda")
del self.stacked_ar.visual, self.stacked_ar.model.embed_tokens, self.stacked_ar.lm_head
    # Inference: image generation
def generate_images(self, prompt, max_new_tokens=256, num_return_sequences=1, cfg_weight=5.0, topk_sample=1000, topp_sample=1.0, temperature=1.0, reasoning=False, return_dict=False):
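        """Generate images from a text prompt with classifier-free guidance.

        Returns decoded VQ images as a numpy array of shape (N, H, W, C), or a dict that
        also carries the raw VQ tokens and optional diffusion-decoded images when
        return_dict=True.
        """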
if reasoning:
return self.generate_images_reasoning(prompt, max_new_tokens, num_return_sequences, cfg_weight, topk_sample, topp_sample, temperature, return_dict)
messages = [{'role': 'user', 'content': prompt}]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
text_token = self.tokenizer.encode(text)
text_token = torch.tensor(text_token).long().to(self.device)
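        # Build a 2x batch for classifier-free guidance: even rows keep the full prompt
        # (conditional); odd rows mask the prompt body with <|fim_pad|>, keeping only the
        # leading token and the trailing assistant / <|vision_start|> tokens (unconditional).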
keys = list(self.special_tokens.keys())
start_token = (torch.ones(1) * self.special_tokens.get(keys[0])).long().to(self.device)
input_ids = torch.cat((text_token, start_token)).long().to(self.device)
tokens = torch.zeros((num_return_sequences*2, len(input_ids)), dtype=torch.int).cuda()
assistant_tokens = input_ids[-4:]
for i in range(num_return_sequences*2):
tokens[i, :] = input_ids
if i % 2 != 0:
tokens[i, 1:-1] = self.special_tokens.get(keys[4])
tokens[i, -4:] = assistant_tokens
inputs_embeds = self.llm.model.embed_tokens(tokens).to(self.device)
generated_tokens = torch.zeros((num_return_sequences, max_new_tokens), dtype=torch.int).cuda()
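        # Autoregressive image-token decoding: the Qwen2.5-VL backbone produces hidden
        # states, the stacked AR head refines them, and pixel_output_head maps the last
        # position to VQ codebook logits; conditional and unconditional logits are mixed
        # with cfg_weight, a token is sampled, and its codebook embedding (through the
        # pixel adapter) is fed back as the next input embedding.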
for i in range(max_new_tokens):
outputs = self.llm.model(
inputs_embeds=inputs_embeds,
use_cache=True,
past_key_values=outputs.past_key_values if i != 0 else None,
output_hidden_states=True)
last_hidden_states = outputs[0]
output_states = self.stacked_ar.model(
inputs_embeds=last_hidden_states,
past_key_values=output_states.past_key_values if i != 0 else None,
output_hidden_states=True,
use_cache=True)
last_hidden_states = output_states.hidden_states[-1]
logits = self.pixel_output_head(last_hidden_states[:, -1, :])
logit_cond = logits[0::2, :]
logit_uncond = logits[1::2, :]
logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
            next_token, _ = self.sample(logits, temperature=temperature, top_k=topk_sample, top_p=topp_sample)
generated_tokens[:, i] = next_token.squeeze(dim=-1)
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
vqgan_embeds = self.pixel_encoder.get_codebook_entry(next_token)
img_embeds = self.pixel_adapter(vqgan_embeds)
inputs_embeds = img_embeds.unsqueeze(dim=1)
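        # Decode the full grid of generated VQ tokens back to pixel space
        # (the latent grid is sqrt(max_new_tokens) tokens per side).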
latent_size = int(math.sqrt(max_new_tokens))
output_images = self.pixel_encoder.decode_code(generated_tokens.to(dtype=torch.int), shape=[num_return_sequences, self.pixel_encoder.config.codebook_embed_dim, latent_size, latent_size])
output_images = output_images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
diff_images = None
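        # Optionally re-render the result with the Lumina2 diffusion decoder,
        # conditioned on the text prompt and the generated VQ codebook embeddings.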
if self.diffusion_decoder is not None:
gen_image_embeds = self.pixel_encoder.get_codebook_entry(generated_tokens)
            if self.args.diffusion_resolution == 512:
                self.diffusion_decoder.pipe.transformer.config.sample_size = 16
            elif self.args.diffusion_resolution == 1024:
                self.diffusion_decoder.pipe.transformer.config.sample_size = 32
diff_images = self.diffusion_decoder.pipe(
prompt,
num_inference_steps=40,
guidance_scale=4.5,
                gen_image_embeds=gen_image_embeds,
control_emd="text",
ori_inp_way=self.diffusion_decoder.transformer.ori_inp_dit,
only_t2i="vqconcat",
img_guidance_scale=1.05,
height=self.args.diffusion_resolution,
width=self.args.diffusion_resolution
).images
if return_dict:
return {"output_images": output_images, "generated_tokens": generated_tokens, "diff_images": diff_images}
return output_images
def answer_text_qwen_vl(self, question, max_new_tokens=256, do_sample=True):
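        """Run a text-only chat turn through the Qwen2.5-VL backbone and return the
        decoded answer (empty string if nothing was generated)."""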
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": question},
],
}
]
# Preparation for inference
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text],
images=None,
videos=None,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.llm.device)
# Inference: Generation of the output
generated_ids = self.llm.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=do_sample)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return output_text[0] if output_text else ""
def generate_images_reasoning(self, prompt, max_new_tokens=256, num_return_sequences=1, cfg_weight=5.0, topk_sample=1000, topp_sample=1.0, temperature=1.0, return_dict=False):
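        """Generate images with an intermediate reasoning step: the backbone first writes
        a textual answer to the prompt, which is spliced back into the generation context
        before image-token decoding (same classifier-free guidance scheme as
        generate_images)."""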
messages = [{'role': 'user', 'content': prompt}]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
text_token = self.tokenizer.encode(text)
text_token = torch.tensor(text_token).long().to(self.device)
keys = list(self.special_tokens.keys())
start_token = (torch.ones(1) * self.special_tokens.get(keys[0])).long().to(self.device)
input_ids = torch.cat((text_token, start_token)).long().to(self.device)
tokens = torch.zeros((num_return_sequences*2, len(input_ids)), dtype=torch.int).cuda()
assistant_tokens = input_ids[-4:]
for i in range(num_return_sequences*2):
tokens[i, :] = input_ids
if i % 2 != 0:
tokens[i, 1:-1] = self.special_tokens.get(keys[4])
tokens[i, -4:] = assistant_tokens
generated_tokens = torch.zeros((num_return_sequences, max_new_tokens), dtype=torch.int).cuda()
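        # Reasoning step: obtain a textual answer with greedy decoding, append a fixed
        # quality-boosting "magic prompt", and splice the result between the fixed
        # chat-template prefix (first 14 tokens) and the trailing template tokens.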
answer_tokens_list = self.answer_text_qwen_vl(prompt, do_sample=False)
if answer_tokens_list:
answer_tokens_list = self.tokenizer.encode(answer_tokens_list, add_special_tokens=False)
answer_tokens = torch.tensor([answer_tokens_list], device=self.device) # [1, seq_len]
magic_prompt = " Ultra HD, 4K, cinematic composition"
magic_prompt_tokens = self.tokenizer.encode(magic_prompt, add_special_tokens=False)
magic_prompt_tensor = torch.tensor([magic_prompt_tokens], device=self.device) # [1, magic_seq_len]
answer_tokens = torch.cat([answer_tokens, magic_prompt_tensor], dim=1) # [1, seq_len + magic_seq_len]
            answer_prompt = self.tokenizer.decode(answer_tokens[0]).split("assistant\n")[-1]  # keep only the answer text after any chat-template prefix
special_token = self.special_tokens.get(keys[4])
special_token_tensor = torch.tensor([[special_token]], device=self.device)
special_token_expanded = special_token_tensor.expand(-1, answer_tokens.size(1))
answer_tokens_with_special = torch.cat([answer_tokens, special_token_expanded], dim=0)
batch_size = tokens.size(0) # num_return_sequences*2
answer_tokens_expanded = answer_tokens_with_special.repeat(batch_size // 2, 1)
input_tokens = torch.cat((tokens[:, :14], answer_tokens_expanded, tokens[:, -6:]), dim=1)
else:
input_tokens = tokens
answer_prompt = None
inputs_embeds = self.llm.model.embed_tokens(input_tokens).to(self.device)
for i in range(max_new_tokens):
outputs = self.llm.model(
inputs_embeds=inputs_embeds,
use_cache=True,
past_key_values=outputs.past_key_values if i != 0 else None,
output_hidden_states=True)
last_hidden_states = outputs[0]
output_states = self.stacked_ar.model(
inputs_embeds=last_hidden_states,
past_key_values=output_states.past_key_values if i != 0 else None,
output_hidden_states=True,
use_cache=True)
last_hidden_states = output_states.hidden_states[-1]
logits = self.pixel_output_head(last_hidden_states[:, -1, :])
logit_cond = logits[0::2, :]
logit_uncond = logits[1::2, :]
logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
            next_token, _ = self.sample(logits, temperature=temperature, top_k=topk_sample, top_p=topp_sample)
generated_tokens[:, i] = next_token.squeeze(dim=-1)
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
vqgan_embeds = self.pixel_encoder.get_codebook_entry(next_token)
img_embeds = self.pixel_adapter(vqgan_embeds)
inputs_embeds = img_embeds.unsqueeze(dim=1)
latent_size = int(math.sqrt(max_new_tokens))
output_images = self.pixel_encoder.decode_code(generated_tokens.to(dtype=torch.int), shape=[num_return_sequences, self.pixel_encoder.config.codebook_embed_dim, latent_size, latent_size])
output_images = output_images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
diff_images = None
if self.diffusion_decoder is not None:
gen_image_embeds = self.pixel_encoder.get_codebook_entry(generated_tokens)
diff_prompt = answer_prompt if answer_prompt else prompt
            if self.args.diffusion_resolution == 512:
                self.diffusion_decoder.pipe.transformer.config.sample_size = 16
            elif self.args.diffusion_resolution == 1024:
                self.diffusion_decoder.pipe.transformer.config.sample_size = 32
diff_images = self.diffusion_decoder.pipe(
diff_prompt,
num_inference_steps=40,
guidance_scale=4.5,
                gen_image_embeds=gen_image_embeds,
control_emd="text",
ori_inp_way=self.diffusion_decoder.transformer.ori_inp_dit,
only_t2i="vqconcat",
img_guidance_scale=1.05,
height=self.args.diffusion_resolution,
width=self.args.diffusion_resolution
).images
if return_dict:
return {"output_images":output_images,"generated_tokens":generated_tokens,"diff_images":diff_images,"answer_prompt":answer_prompt}
return output_images
    def generate_images_edit(self, image, prompt, max_new_tokens=256, num_return_sequences=1, cfg_weight=5.0, topk_sample=1000, topp_sample=1.0, temperature=1.0, return_dict=False):
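        """Image editing: condition generation on one or more source images in addition
        to the text prompt. Each source image is injected both as ViT features from the
        Qwen2.5-VL vision tower and as VQ codebook embeddings, with 2D RoPE position ids
        computed over the expanded sequence."""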
vq_image_transform = get_vq_transform(self.args)
full_image_transform = get_full_transform(self.args)
if isinstance(image, str):
image = Image.open(image).convert('RGB')
elif isinstance(image, list):
image = [each_image.convert('RGB') for each_image in image]
else:
image = image.convert('RGB')
messages = [{'role': 'user', 'content': prompt}]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
text_token = self.tokenizer.encode(text)
text_token = torch.tensor(text_token).long().to(self.device)
keys = list(self.special_tokens.keys())
start_token = (torch.ones(1) * self.special_tokens.get(keys[0])).long().to(self.device)
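        # Locate the position right after the "<|im_start|>user\n" tokens; the image
        # embeddings are inserted there, at the start of the user turn.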
user_prompt = "<|im_start|>user\n"
user_prompt_token = self.tokenizer.encode(user_prompt, add_special_tokens=False)
user_prompt_tensor = torch.tensor(user_prompt_token).long().to(self.device)
windows = text_token.unfold(0, len(user_prompt_tensor), 1)
matches = (windows == user_prompt_tensor).all(dim=1)
image_position = torch.where(matches)[0][0].item() + len(user_prompt_tensor)
input_ids = torch.cat((text_token, start_token)).long().to(self.device)
tokens = torch.zeros((num_return_sequences*2, len(input_ids)), dtype=torch.int).cuda()
assistant_tokens = input_ids[-4:]
for i in range(num_return_sequences*2):
tokens[i, :] = input_ids
if i % 2 != 0:
tokens[i, 1:-1] = self.special_tokens.get(keys[4])
tokens[i, -4:] = assistant_tokens
inputs_embeds = self.llm.model.embed_tokens(tokens).to(self.device)
position_ids = None
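        # Inject the source image(s): ViT features and VQ codebook embeddings (through the
        # pixel adapter) are spliced into the prompt embeddings, each segment wrapped in
        # <|vision_start|>/<|vision_end|>, and a matching id sequence with placeholder
        # tokens is built so that 2D RoPE position ids can be computed.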
if image is not None:
image_info = preprocess_image_gen(image, self.image_processor, vq_image_transform)
image_embeds = self.llm.visual(image_info["pixel_values"].to(inputs_embeds.device,self.llm.visual.dtype), grid_thw=image_info["image_grid_thw"].to(inputs_embeds.device))
image_embeds = image_embeds[None,:].repeat(2, 1, 1).to(inputs_embeds.device, inputs_embeds.dtype)
vq_pixel_values = image_info["vq_pixel_values"].to(inputs_embeds.device)
B = inputs_embeds.size(0)
if len(vq_pixel_values.shape)==4:
vq_pixel_values = vq_pixel_values[:,None]
N = vq_pixel_values.size(1)
_, _, [_, _, vq_indices] = self.pixel_encoder.encode(vq_pixel_values.flatten(0, 1).bfloat16())
batch_size = vq_pixel_values.shape[0]
vq_indices = vq_indices.reshape(batch_size, N, vq_indices.shape[-1])
vqgan_dec_embeds = self.pixel_encoder.get_codebook_entry(vq_indices)
vq_embeds = self.pixel_adapter(vqgan_dec_embeds)
vq_embeds = vq_embeds.repeat(B, 1, 1, 1).to(inputs_embeds.device, inputs_embeds.dtype).flatten(1, 2)
vision_start_embeds = self.llm.model.embed_tokens(torch.tensor(self.tokenizer.encode("<|vision_start|>")).long().to(self.device))
vision_end_embeds = self.llm.model.embed_tokens(torch.tensor(self.tokenizer.encode("<|vision_end|>")).long().to(self.device))
newline_embeds = self.llm.model.embed_tokens(torch.tensor(self.tokenizer.encode("\n")).long().to(self.device))
vision_start_embeds = vision_start_embeds.unsqueeze(0).repeat(B, 1, 1)
vision_end_embeds = vision_end_embeds.unsqueeze(0).repeat(B, 1, 1)
newline_embeds = newline_embeds.unsqueeze(0).repeat(B, 1, 1)
inputs_embeds = torch.cat((inputs_embeds[:, :image_position],
vision_start_embeds, vq_embeds, vision_end_embeds,
vision_start_embeds, image_embeds, vision_end_embeds, newline_embeds,
inputs_embeds[:, image_position:]), dim=1)
SPECIAL_VQ_TOKEN = '<|vision_pad|>'
SPECIAL_VIT_TOKEN = '<|image_pad|>'
SPECIAL_VQ_TOKEN_ID = self.tokenizer.encode(SPECIAL_VQ_TOKEN)[0]
SPECIAL_VIT_TOKEN_ID = self.tokenizer.encode(SPECIAL_VIT_TOKEN)[0]
input_ids_for_position = torch.cat([input_ids[:image_position],
torch.tensor(self.tokenizer.encode("<|vision_start|>")).to(vq_embeds.device), torch.full((vq_embeds.shape[-2],), SPECIAL_VQ_TOKEN_ID, device=vq_embeds.device), torch.tensor(self.tokenizer.encode("<|vision_end|>")).to(vq_embeds.device),
torch.tensor(self.tokenizer.encode("<|vision_start|>")).to(vq_embeds.device), torch.full((image_embeds.shape[-2],), SPECIAL_VIT_TOKEN_ID, device=vq_embeds.device), torch.tensor(self.tokenizer.encode("<|vision_end|>")).to(vq_embeds.device), torch.tensor(self.tokenizer.encode("\n")).to(vq_embeds.device),
input_ids[image_position:],torch.full((vq_embeds.shape[-2],), SPECIAL_VQ_TOKEN_ID, device=vq_embeds.device)], dim=0)
position_ids, _ = get_rope_index_25(
self.image_processor.merge_size,
input_ids_for_position[None],
image_grid_thw=image_info["image_grid_thw"],
video_grid_thw=None,
second_per_grid_ts=None,
)
generated_tokens = torch.zeros((num_return_sequences, max_new_tokens), dtype=torch.int).cuda()
for i in range(max_new_tokens):
if i != 0:
real_position = position_ids[:,:,outputs.past_key_values.seen_tokens:(outputs.past_key_values.seen_tokens+inputs_embeds.shape[1])].to(inputs_embeds.device)
else:
real_position = position_ids[:,:,:inputs_embeds.shape[1]].to(inputs_embeds.device)
outputs = self.llm.model(
inputs_embeds=inputs_embeds,
use_cache=True,
position_ids = real_position,
past_key_values=outputs.past_key_values if i != 0 else None,
output_hidden_states=True)
last_hidden_states = outputs[0]
output_states = self.stacked_ar.model(
inputs_embeds=last_hidden_states,
past_key_values=output_states.past_key_values if i != 0 else None,
output_hidden_states=True,
position_ids = real_position,
use_cache=True)
last_hidden_states = output_states.hidden_states[-1]
logits = self.pixel_output_head(last_hidden_states[:, -1, :])
logit_cond = logits[0::2, :]
logit_uncond = logits[1::2, :]
logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
            next_token, _ = self.sample(logits, temperature=temperature, top_k=topk_sample, top_p=topp_sample)
generated_tokens[:, i] = next_token.squeeze(dim=-1)
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
vqgan_embeds = self.pixel_encoder.get_codebook_entry(next_token)
img_embeds = self.pixel_adapter(vqgan_embeds)
inputs_embeds = img_embeds.unsqueeze(dim=1)
latent_size = int(math.sqrt(max_new_tokens))
output_images = self.pixel_encoder.decode_code(generated_tokens.to(dtype=torch.int), shape=[num_return_sequences, self.pixel_encoder.config.codebook_embed_dim, latent_size, latent_size])
output_images = output_images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
diff_images = None
if self.diffusion_decoder is not None:
gen_image_embeds = self.pixel_encoder.get_codebook_entry(generated_tokens)
if isinstance(image, list):
processed_img = [full_image_transform(each_image) for each_image in image]
else:
processed_img = [full_image_transform(image)]
            if self.args.diffusion_resolution == 512:
                self.diffusion_decoder.pipe.transformer.config.sample_size = 16
            elif self.args.diffusion_resolution == 1024:
                self.diffusion_decoder.pipe.transformer.config.sample_size = 32
diff_images = self.diffusion_decoder.pipe(
prompt,
num_inference_steps=50,
guidance_scale=3.0,
                gen_image_embeds=gen_image_embeds,
                control_emd="text",
                ori_inp_img=processed_img[0],
                ori_inp_way="seq",
                only_t2i="vqconcat",
                img_guidance_scale=1.8,
                vq_guidance_scale=1,
                height=self.args.diffusion_resolution,
                width=self.args.diffusion_resolution
).images
if return_dict:
return {"output_images": output_images, "generated_tokens": None, "diff_images": diff_images}
return None
def sample(self, logits, temperature: float=1.0, top_k: int=0, top_p: float=1.0, sample_logits=True):
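        """Sample one token id per row from `logits` after temperature scaling and optional
        top-k / top-p filtering; returns (indices, probabilities)."""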
logits = logits / max(temperature, 1e-5)
if top_k > 0 or top_p < 1.0:
logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
probs = F.softmax(logits, dim=-1)
if sample_logits:
idx = torch.multinomial(probs, num_samples=1)
else:
_, idx = torch.topk(probs, k=1, dim=-1)
return idx, probs
def top_k_top_p_filtering(
self,
logits,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens with cumulative probability above the threshold (tokens marked 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
    # Inference: image understanding
def preprocess_image(self, image):
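        """Load an image given as a local path, URL, PIL image, or list (first element is
        used) and return it as an RGB PIL image."""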
if image is None:
return None
if isinstance(image, str):
if os.path.exists(image):
pil_image = Image.open(image).convert('RGB')
else:
response = requests.get(image)
if response.status_code == 200:
image_bytes = BytesIO(response.content)
pil_image = Image.open(image_bytes).convert('RGB')
else:
raise ValueError(f"Failed to load image from url {image}")
elif isinstance(image, Image.Image):
pil_image = image.convert('RGB')
elif isinstance(image, list):
return self.preprocess_image(image[0])
else:
raise ValueError("Unsupported image type")
return pil_image
def inference_understand(self, image, question, max_new_tokens=256):
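        """Visual question answering: answer `question` about `image` using the Qwen2.5-VL
        backbone and return the decoded answer string."""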
pil_image = self.preprocess_image(image)
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": pil_image,
},
{"type": "text", "text": question},
],
}
]
from qwen_vl_utils import process_vision_info
# Preparation for inference
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.llm.device)
# Inference: Generation of the output
generated_ids = self.llm.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return output_text[0] if output_text else ""
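

# A minimal usage sketch (hypothetical checkpoint path and args object; the exact config
# fields live in star/models/config.py and the inference argument parser):
#
#   config = STARMultiModalConfig.from_pretrained("path/to/star-checkpoint")
#   model = STARMultiModal(config, args=args)
#   images = model.generate_images("a red bicycle leaning against a brick wall",
#                                  max_new_tokens=256, cfg_weight=5.0)
#   answer = model.inference_understand("path/to/photo.jpg", "What is shown in the image?")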