# star/models/model.py
import os
import math
import torch
import requests
from io import BytesIO
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from torch.nn import CrossEntropyLoss
from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForCausalLM,
PreTrainedModel
)
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, Qwen2VLProcessor
from star.models.config import STARMultiModalConfig
from star.models.pixel_encoder.vq_model import VQ_Model
from star.models.adapter.projector import MlpProjector
from star.models.pixel_decoder.lumina2_decoder import Lumina2Decoder
from star.models.data_process_utils import get_full_transform, get_vq_transform, preprocess_image_gen
from star.models.rope_2d import get_rope_index_25
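# STARMultiModal composes: a VQ pixel encoder, an MLP pixel adapter, a linear pixel output
# head over the VQ codebook, an optional Lumina2 diffusion decoder, a Qwen2.5-VL backbone,
# and a "stacked AR" head built from the backbone's last transformer layers.
#
# Typical usage (a sketch; the exact config/args construction depends on this repo's
# loading scripts, so treat the names below as illustrative):
#   model = STARMultiModal(config, args=args)
#   images = model.generate_images("a red bicycle on a beach", max_new_tokens=256)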
class STARMultiModal(PreTrainedModel):
def __init__(self, config: STARMultiModalConfig, args=None, **kwargs):
super().__init__(config)
self.config = config
self.args = args if args is not None else kwargs.get("args", None)
# Pixel Encoder Generation
model_name = config.pixel_encoder.model_name
if model_name == "VQ_Model":
self.pixel_encoder = VQ_Model(config.pixel_encoder)
else:
assert None, f"Unsupported {model_name}"
self.pixel_encoder.eval()
# Pixel Adapter Generation
model_name = config.pixel_adapter.model_name
if model_name == "MLP_GELU":
self.pixel_adapter = MlpProjector(config.pixel_adapter)
else:
assert None, f"Unsupported {model_name}"
        # Pixel Output Head Generation
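        # Linear head mapping stacked-AR hidden states (n_embed) to logits over the VQ codebook (image_token_size).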
self.pixel_output_head = torch.nn.Linear(config.pixel_output_head.n_embed, config.pixel_output_head.image_token_size)
if getattr(args, "diffusion_as_decoder") and args.diffusion_as_decoder:
self.diffusion_decoder = Lumina2Decoder(config.pixel_decoder, args)
else:
self.diffusion_decoder = None
# Large Language Model
model_name, model_path = config.language_model.model_name, config.language_model.model_path
if model_name == "Qwen2.5-VL":
self.llm = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="cuda")
self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
self.tokenizer = self.processor.tokenizer
self.image_processor = self.processor.image_processor
self.image_processor.max_pixels = self.args.max_pixels
self.image_processor.min_pixels = self.args.min_pixels
self.image_processor.size["longest_edge"] = self.args.max_pixels
self.image_processor.size["shortest_edge"] = self.args.min_pixels
special_token_tags = ["<|vision_start|>", "<|vision_pad|>", "<|image_pad|>", "<|vision_end|>", "<|fim_pad|>"]
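            # Cache the ids of these special tokens. keys[0] ("<|vision_start|>") marks the start of
            # image-token generation; keys[4] ("<|fim_pad|>") is used to blank the prompt for the
            # unconditional branch of classifier-free guidance.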
self.special_tokens = {tag: self.tokenizer.vocab.get(tag, None) for tag in special_token_tags}
else:
assert None, f"unsupported {model_name}: {model_path}"
self.llm.generation_config.pad_token_id = self.tokenizer.encode(self.tokenizer.pad_token)[0]
if self.args.grad_ckpt:
self.llm.gradient_checkpointing_enable()
self.llm.visual.gradient_checkpointing_enable()
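        # Stacked AR head: a truncated Qwen2.5-VL that keeps only the last `num_layers` transformer
        # blocks and consumes the backbone's hidden states directly, so its vision tower, token
        # embeddings, and LM head are deleted once the weights have been copied over.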
stacked_ar_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
num_layers_to_extract = config.stacked_ar.num_layers
stacked_ar_config.num_hidden_layers = num_layers_to_extract
self.stacked_ar = Qwen2_5_VLForConditionalGeneration(stacked_ar_config)
temp_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
total_layers = len(temp_model.model.layers)
start_layer = max(0, total_layers - num_layers_to_extract)
temp_model.model.layers = temp_model.model.layers[start_layer:]
self.stacked_ar.load_state_dict(temp_model.state_dict(), strict=False)
self.stacked_ar = self.stacked_ar.to("cuda")
del self.stacked_ar.visual, self.stacked_ar.model.embed_tokens, self.stacked_ar.lm_head
# For Inference Generation
def generate_images(self, prompt, max_new_tokens=256, num_return_sequences=1, cfg_weight=5.0, topk_sample=1000, topp_sample=1.0, temperature=1.0, reasoning=False, return_dict=False):
if reasoning:
return self.generate_images_reasoning(prompt, max_new_tokens, num_return_sequences, cfg_weight, topk_sample, topp_sample, temperature, return_dict)
messages = [{'role': 'user', 'content': prompt}]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
text_token = self.tokenizer.encode(text)
text_token = torch.tensor(text_token).long().to(self.device)
keys = list(self.special_tokens.keys())
start_token = (torch.ones(1) * self.special_tokens.get(keys[0])).long().to(self.device)
input_ids = torch.cat((text_token, start_token)).long().to(self.device)
tokens = torch.zeros((num_return_sequences*2, len(input_ids)), dtype=torch.int).cuda()
assistant_tokens = input_ids[-4:]
for i in range(num_return_sequences*2):
tokens[i, :] = input_ids
if i % 2 != 0:
tokens[i, 1:-1] = self.special_tokens.get(keys[4])
tokens[i, -4:] = assistant_tokens
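        # Classifier-free guidance layout: even rows keep the full prompt (conditional); odd rows
        # have everything between the first token and the assistant suffix replaced by <|fim_pad|>
        # (unconditional).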
inputs_embeds = self.llm.model.embed_tokens(tokens).to(self.device)
generated_tokens = torch.zeros((num_return_sequences, max_new_tokens), dtype=torch.int).cuda()
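        # Autoregressive image-token loop: each step runs the Qwen2.5-VL backbone, feeds its hidden
        # states through the stacked AR head, projects the last position to VQ-codebook logits,
        # mixes conditional/unconditional logits with CFG, samples a token, and feeds that token's
        # codebook embedding (via the pixel adapter) back in as the next input embedding.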
for i in range(max_new_tokens):
outputs = self.llm.model(
inputs_embeds=inputs_embeds,
use_cache=True,
past_key_values=outputs.past_key_values if i != 0 else None,
output_hidden_states=True)
last_hidden_states = outputs[0]
output_states = self.stacked_ar.model(
inputs_embeds=last_hidden_states,
past_key_values=output_states.past_key_values if i != 0 else None,
output_hidden_states=True,
use_cache=True)
last_hidden_states = output_states.hidden_states[-1]
logits = self.pixel_output_head(last_hidden_states[:, -1, :])
logit_cond = logits[0::2, :]
logit_uncond = logits[1::2, :]
logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
            next_token, _ = self.sample(logits, temperature=temperature, top_k=topk_sample, top_p=topp_sample)
generated_tokens[:, i] = next_token.squeeze(dim=-1)
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
vqgan_embeds = self.pixel_encoder.get_codebook_entry(next_token)
img_embeds = self.pixel_adapter(vqgan_embeds)
inputs_embeds = img_embeds.unsqueeze(dim=1)
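        # Decode the generated VQ indices back to pixels; the latent grid is square with
        # side sqrt(max_new_tokens).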
latent_size = int(math.sqrt(max_new_tokens))
output_images = self.pixel_encoder.decode_code(generated_tokens.to(dtype=torch.int), shape=[num_return_sequences, self.pixel_encoder.config.codebook_embed_dim, latent_size, latent_size])
output_images = output_images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
diff_images = None
if self.diffusion_decoder is not None:
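            # Optional refinement: the Lumina2 diffusion decoder is conditioned on the generated VQ
            # codebook embeddings; sample_size is set to match the requested output resolution.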
gen_image_embeds = self.pixel_encoder.get_codebook_entry(generated_tokens)
if self.args.diffusion_resolution==512:
self.diffusion_decoder.pipe.transformer.config.sample_size=16
elif self.args.diffusion_resolution==1024:
self.diffusion_decoder.pipe.transformer.config.sample_size=32
diff_images = self.diffusion_decoder.pipe(
prompt,
num_inference_steps=40,
guidance_scale=4.5,
                gen_image_embeds=gen_image_embeds,
control_emd="text",
ori_inp_way=self.diffusion_decoder.transformer.ori_inp_dit,
only_t2i="vqconcat",
img_guidance_scale=1.05,
height=self.args.diffusion_resolution,
width=self.args.diffusion_resolution
).images
if return_dict:
return {"output_images": output_images, "generated_tokens": generated_tokens, "diff_images": diff_images}
return output_images
def answer_text_qwen_vl(self, question, max_new_tokens=256, do_sample=True):
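        # Text-only pass through the Qwen2.5-VL backbone, used by generate_images_reasoning to
        # expand the user prompt into a more detailed caption before image generation.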
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": question},
],
}
]
# Preparation for inference
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = self.processor(
text=[text],
images=None,
videos=None,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.llm.device)
# Inference: Generation of the output
generated_ids = self.llm.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=do_sample)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return output_text[0] if output_text else ""
def generate_images_reasoning(self, prompt, max_new_tokens=256, num_return_sequences=1, cfg_weight=5.0, topk_sample=1000, topp_sample=1.0, temperature=1.0, return_dict=False):
messages = [{'role': 'user', 'content': prompt}]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
text_token = self.tokenizer.encode(text)
text_token = torch.tensor(text_token).long().to(self.device)
keys = list(self.special_tokens.keys())
start_token = (torch.ones(1) * self.special_tokens.get(keys[0])).long().to(self.device)
input_ids = torch.cat((text_token, start_token)).long().to(self.device)
tokens = torch.zeros((num_return_sequences*2, len(input_ids)), dtype=torch.int).cuda()
assistant_tokens = input_ids[-4:]
for i in range(num_return_sequences*2):
tokens[i, :] = input_ids
if i % 2 != 0:
tokens[i, 1:-1] = self.special_tokens.get(keys[4])
tokens[i, -4:] = assistant_tokens
generated_tokens = torch.zeros((num_return_sequences, max_new_tokens), dtype=torch.int).cuda()
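        # Reasoning step: ask the backbone for a text answer to the prompt, append a fixed
        # quality-boosting suffix, and splice the result into the conditional rows; the
        # unconditional rows keep their <|fim_pad|> padding.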
answer_tokens_list = self.answer_text_qwen_vl(prompt, do_sample=False)
if answer_tokens_list:
answer_tokens_list = self.tokenizer.encode(answer_tokens_list, add_special_tokens=False)
answer_tokens = torch.tensor([answer_tokens_list], device=self.device) # [1, seq_len]
magic_prompt = " Ultra HD, 4K, cinematic composition"
magic_prompt_tokens = self.tokenizer.encode(magic_prompt, add_special_tokens=False)
magic_prompt_tensor = torch.tensor([magic_prompt_tokens], device=self.device) # [1, magic_seq_len]
answer_tokens = torch.cat([answer_tokens, magic_prompt_tensor], dim=1) # [1, seq_len + magic_seq_len]
            answer_prompt = self.tokenizer.decode(answer_tokens[0]).split("assistant\n")[-1]
special_token = self.special_tokens.get(keys[4])
special_token_tensor = torch.tensor([[special_token]], device=self.device)
special_token_expanded = special_token_tensor.expand(-1, answer_tokens.size(1))
answer_tokens_with_special = torch.cat([answer_tokens, special_token_expanded], dim=0)
batch_size = tokens.size(0) # num_return_sequences*2
answer_tokens_expanded = answer_tokens_with_special.repeat(batch_size // 2, 1)
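            # Splice the answer between the chat-template prefix and suffix (offsets 14 and -6 are
            # assumed to match the fixed Qwen2.5-VL chat template used here).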
input_tokens = torch.cat((tokens[:, :14], answer_tokens_expanded, tokens[:, -6:]), dim=1)
else:
input_tokens = tokens
answer_prompt = None
inputs_embeds = self.llm.model.embed_tokens(input_tokens).to(self.device)
for i in range(max_new_tokens):
outputs = self.llm.model(
inputs_embeds=inputs_embeds,
use_cache=True,
past_key_values=outputs.past_key_values if i != 0 else None,
output_hidden_states=True)
last_hidden_states = outputs[0]
output_states = self.stacked_ar.model(
inputs_embeds=last_hidden_states,
past_key_values=output_states.past_key_values if i != 0 else None,
output_hidden_states=True,
use_cache=True)
last_hidden_states = output_states.hidden_states[-1]
logits = self.pixel_output_head(last_hidden_states[:, -1, :])
logit_cond = logits[0::2, :]
logit_uncond = logits[1::2, :]
logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
            next_token, _ = self.sample(logits, temperature=temperature, top_k=topk_sample, top_p=topp_sample)
generated_tokens[:, i] = next_token.squeeze(dim=-1)
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
vqgan_embeds = self.pixel_encoder.get_codebook_entry(next_token)
img_embeds = self.pixel_adapter(vqgan_embeds)
inputs_embeds = img_embeds.unsqueeze(dim=1)
latent_size = int(math.sqrt(max_new_tokens))
output_images = self.pixel_encoder.decode_code(generated_tokens.to(dtype=torch.int), shape=[num_return_sequences, self.pixel_encoder.config.codebook_embed_dim, latent_size, latent_size])
output_images = output_images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
diff_images = None
if self.diffusion_decoder is not None:
gen_image_embeds = self.pixel_encoder.get_codebook_entry(generated_tokens)
diff_prompt = answer_prompt if answer_prompt else prompt
if self.args.diffusion_resolution==512:
self.diffusion_decoder.pipe.transformer.config.sample_size=16
elif self.args.diffusion_resolution==1024:
self.diffusion_decoder.pipe.transformer.config.sample_size=32
diff_images = self.diffusion_decoder.pipe(
diff_prompt,
num_inference_steps=40,
guidance_scale=4.5,
                gen_image_embeds=gen_image_embeds,
control_emd="text",
ori_inp_way=self.diffusion_decoder.transformer.ori_inp_dit,
only_t2i="vqconcat",
img_guidance_scale=1.05,
height=self.args.diffusion_resolution,
width=self.args.diffusion_resolution
).images
if return_dict:
return {"output_images":output_images,"generated_tokens":generated_tokens,"diff_images":diff_images,"answer_prompt":answer_prompt}
return output_images
    def generate_images_edit(self, image, prompt, max_new_tokens=256, num_return_sequences=1, cfg_weight=5.0, topk_sample=1000, topp_sample=1.0, temperature=1.0, return_dict=False):
vq_image_transform = get_vq_transform(self.args)
full_image_transform = get_full_transform(self.args)
if isinstance(image, str):
image = Image.open(image).convert('RGB')
elif isinstance(image, list):
image = [each_image.convert('RGB') for each_image in image]
else:
image = image.convert('RGB')
messages = [{'role': 'user', 'content': prompt}]
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
text_token = self.tokenizer.encode(text)
text_token = torch.tensor(text_token).long().to(self.device)
keys = list(self.special_tokens.keys())
start_token = (torch.ones(1) * self.special_tokens.get(keys[0])).long().to(self.device)
user_prompt = "<|im_start|>user\n"
user_prompt_token = self.tokenizer.encode(user_prompt, add_special_tokens=False)
user_prompt_tensor = torch.tensor(user_prompt_token).long().to(self.device)
windows = text_token.unfold(0, len(user_prompt_tensor), 1)
matches = (windows == user_prompt_tensor).all(dim=1)
image_position = torch.where(matches)[0][0].item() + len(user_prompt_tensor)
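        # image_position points just past "<|im_start|>user\n"; the reference-image embeddings are
        # spliced into the sequence at this offset.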
input_ids = torch.cat((text_token, start_token)).long().to(self.device)
tokens = torch.zeros((num_return_sequences*2, len(input_ids)), dtype=torch.int).cuda()
assistant_tokens = input_ids[-4:]
for i in range(num_return_sequences*2):
tokens[i, :] = input_ids
if i % 2 != 0:
tokens[i, 1:-1] = self.special_tokens.get(keys[4])
tokens[i, -4:] = assistant_tokens
inputs_embeds = self.llm.model.embed_tokens(tokens).to(self.device)
position_ids = None
if image is not None:
image_info = preprocess_image_gen(image, self.image_processor, vq_image_transform)
image_embeds = self.llm.visual(image_info["pixel_values"].to(inputs_embeds.device,self.llm.visual.dtype), grid_thw=image_info["image_grid_thw"].to(inputs_embeds.device))
image_embeds = image_embeds[None,:].repeat(2, 1, 1).to(inputs_embeds.device, inputs_embeds.dtype)
vq_pixel_values = image_info["vq_pixel_values"].to(inputs_embeds.device)
B = inputs_embeds.size(0)
if len(vq_pixel_values.shape)==4:
vq_pixel_values = vq_pixel_values[:,None]
N = vq_pixel_values.size(1)
_, _, [_, _, vq_indices] = self.pixel_encoder.encode(vq_pixel_values.flatten(0, 1).bfloat16())
batch_size = vq_pixel_values.shape[0]
vq_indices = vq_indices.reshape(batch_size, N, vq_indices.shape[-1])
vqgan_dec_embeds = self.pixel_encoder.get_codebook_entry(vq_indices)
vq_embeds = self.pixel_adapter(vqgan_dec_embeds)
vq_embeds = vq_embeds.repeat(B, 1, 1, 1).to(inputs_embeds.device, inputs_embeds.dtype).flatten(1, 2)
vision_start_embeds = self.llm.model.embed_tokens(torch.tensor(self.tokenizer.encode("<|vision_start|>")).long().to(self.device))
vision_end_embeds = self.llm.model.embed_tokens(torch.tensor(self.tokenizer.encode("<|vision_end|>")).long().to(self.device))
newline_embeds = self.llm.model.embed_tokens(torch.tensor(self.tokenizer.encode("\n")).long().to(self.device))
vision_start_embeds = vision_start_embeds.unsqueeze(0).repeat(B, 1, 1)
vision_end_embeds = vision_end_embeds.unsqueeze(0).repeat(B, 1, 1)
newline_embeds = newline_embeds.unsqueeze(0).repeat(B, 1, 1)
inputs_embeds = torch.cat((inputs_embeds[:, :image_position],
vision_start_embeds, vq_embeds, vision_end_embeds,
vision_start_embeds, image_embeds, vision_end_embeds, newline_embeds,
inputs_embeds[:, image_position:]), dim=1)
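            # The conditioning sequence now contains two vision spans for the source image: its
            # VQ-token embeddings and its ViT features, each wrapped in <|vision_start|>/<|vision_end|>.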
SPECIAL_VQ_TOKEN = '<|vision_pad|>'
SPECIAL_VIT_TOKEN = '<|image_pad|>'
SPECIAL_VQ_TOKEN_ID = self.tokenizer.encode(SPECIAL_VQ_TOKEN)[0]
SPECIAL_VIT_TOKEN_ID = self.tokenizer.encode(SPECIAL_VIT_TOKEN)[0]
input_ids_for_position = torch.cat([input_ids[:image_position],
torch.tensor(self.tokenizer.encode("<|vision_start|>")).to(vq_embeds.device), torch.full((vq_embeds.shape[-2],), SPECIAL_VQ_TOKEN_ID, device=vq_embeds.device), torch.tensor(self.tokenizer.encode("<|vision_end|>")).to(vq_embeds.device),
torch.tensor(self.tokenizer.encode("<|vision_start|>")).to(vq_embeds.device), torch.full((image_embeds.shape[-2],), SPECIAL_VIT_TOKEN_ID, device=vq_embeds.device), torch.tensor(self.tokenizer.encode("<|vision_end|>")).to(vq_embeds.device), torch.tensor(self.tokenizer.encode("\n")).to(vq_embeds.device),
input_ids[image_position:],torch.full((vq_embeds.shape[-2],), SPECIAL_VQ_TOKEN_ID, device=vq_embeds.device)], dim=0)
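            # Compute multimodal (M-RoPE) position ids from the placeholder id sequence above, which
            # mirrors the embedding layout (the two image spans plus the VQ tokens still to be generated).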
position_ids, _ = get_rope_index_25(
self.image_processor.merge_size,
input_ids_for_position[None],
image_grid_thw=image_info["image_grid_thw"],
video_grid_thw=None,
second_per_grid_ts=None,
)
generated_tokens = torch.zeros((num_return_sequences, max_new_tokens), dtype=torch.int).cuda()
for i in range(max_new_tokens):
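            # Slice the precomputed M-RoPE positions to the tokens fed at this step:
            # the full prompt at step 0, then one generated token per step.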
if i != 0:
real_position = position_ids[:,:,outputs.past_key_values.seen_tokens:(outputs.past_key_values.seen_tokens+inputs_embeds.shape[1])].to(inputs_embeds.device)
else:
real_position = position_ids[:,:,:inputs_embeds.shape[1]].to(inputs_embeds.device)
outputs = self.llm.model(
inputs_embeds=inputs_embeds,
use_cache=True,
position_ids = real_position,
past_key_values=outputs.past_key_values if i != 0 else None,
output_hidden_states=True)
last_hidden_states = outputs[0]
output_states = self.stacked_ar.model(
inputs_embeds=last_hidden_states,
past_key_values=output_states.past_key_values if i != 0 else None,
output_hidden_states=True,
position_ids = real_position,
use_cache=True)
last_hidden_states = output_states.hidden_states[-1]
logits = self.pixel_output_head(last_hidden_states[:, -1, :])
logit_cond = logits[0::2, :]
logit_uncond = logits[1::2, :]
logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
            next_token, _ = self.sample(logits, temperature=temperature, top_k=topk_sample, top_p=topp_sample)
generated_tokens[:, i] = next_token.squeeze(dim=-1)
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
vqgan_embeds = self.pixel_encoder.get_codebook_entry(next_token)
img_embeds = self.pixel_adapter(vqgan_embeds)
inputs_embeds = img_embeds.unsqueeze(dim=1)
latent_size = int(math.sqrt(max_new_tokens))
output_images = self.pixel_encoder.decode_code(generated_tokens.to(dtype=torch.int), shape=[num_return_sequences, self.pixel_encoder.config.codebook_embed_dim, latent_size, latent_size])
output_images = output_images.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
diff_images = None
if self.diffusion_decoder is not None:
gen_image_embeds = self.pixel_encoder.get_codebook_entry(generated_tokens)
if isinstance(image, list):
processed_img = [full_image_transform(each_image) for each_image in image]
else:
processed_img = [full_image_transform(image)]
if self.args.diffusion_resolution==512:
self.diffusion_decoder.pipe.transformer.config.sample_size=16
elif self.args.diffusion_resolution==1024:
self.diffusion_decoder.pipe.transformer.config.sample_size=32
diff_images = self.diffusion_decoder.pipe(
prompt,
num_inference_steps=50,
guidance_scale=3.0,
                gen_image_embeds=gen_image_embeds,
                control_emd="text",
                ori_inp_img=processed_img[0],
                ori_inp_way="seq",
                only_t2i="vqconcat",
                img_guidance_scale=1.8,
                vq_guidance_scale=1,
                height=self.args.diffusion_resolution,
                width=self.args.diffusion_resolution
).images
if return_dict:
return {"output_images": output_images, "generated_tokens": None, "diff_images": diff_images}
        return output_images
def sample(self, logits, temperature: float=1.0, top_k: int=0, top_p: float=1.0, sample_logits=True):
logits = logits / max(temperature, 1e-5)
if top_k > 0 or top_p < 1.0:
logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
probs = F.softmax(logits, dim=-1)
if sample_logits:
idx = torch.multinomial(probs, num_samples=1)
else:
_, idx = torch.topk(probs, k=1, dim=-1)
return idx, probs
def top_k_top_p_filtering(
self,
logits,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens with cumulative probability above the threshold (the first token that
            # crosses it is kept via the shift below)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
# For Inference Understand
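    # preprocess_image accepts a local path, an http(s) URL, a PIL image, or a list (only the
    # first element is used) and returns an RGB PIL image for the understanding pipeline.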
def preprocess_image(self, image):
if image is None:
return None
if isinstance(image, str):
if os.path.exists(image):
pil_image = Image.open(image).convert('RGB')
else:
response = requests.get(image)
if response.status_code == 200:
image_bytes = BytesIO(response.content)
pil_image = Image.open(image_bytes).convert('RGB')
else:
raise ValueError(f"Failed to load image from url {image}")
elif isinstance(image, Image.Image):
pil_image = image.convert('RGB')
elif isinstance(image, list):
return self.preprocess_image(image[0])
else:
raise ValueError("Unsupported image type")
return pil_image
def inference_understand(self, image, question, max_new_tokens=256):
pil_image = self.preprocess_image(image)
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": pil_image,
},
{"type": "text", "text": question},
],
}
]
from qwen_vl_utils import process_vision_info
# Preparation for inference
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.llm.device)
# Inference: Generation of the output
generated_ids = self.llm.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return output_text[0] if output_text else ""