# from:
# https://gist.github.com/maedtb/ee16101ca80638011c975ed0c0d78497
# https://github.com/fpgaminer/joycaption/issues/3#issuecomment-2619253277
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
IMAGE_PATH = "dog.jpg"
PROMPT = "Write a long descriptive caption for this image in a formal tone."
MODEL_NAME = "John6666/llama-joycaption-alpha-two-vqa-test-1-nf4"
IS_4BIT_MODE = True
MODEL_NATIVE_DTYPE = torch.bfloat16
# Make example output less random
torch.manual_seed(42)
# If 4-bit mode is enabled, build our quantization config
kwargs = {}
if IS_4BIT_MODE:
    kwargs['quantization_config'] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=MODEL_NATIVE_DTYPE,
        bnb_4bit_quant_storage=MODEL_NATIVE_DTYPE,
    )
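# Sanity check (not in the original gist): bitsandbytes 4-bit inference
# requires a CUDA device, so fail early with a clear message if none exists.
if IS_4BIT_MODE and not torch.cuda.is_available():
    raise RuntimeError("4-bit (bitsandbytes) mode requires a CUDA GPU.")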
# Load JoyCaption
# bfloat16 is the native dtype of the LLM used in JoyCaption (Llama 3.1)
# device_map=device places the entire model on the first GPU (cuda:0)
device = torch.device('cuda:0')
processor = AutoProcessor.from_pretrained(MODEL_NAME)
llava_model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=MODEL_NATIVE_DTYPE, device_map=device, **kwargs)
llava_model.eval()
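# Optional (not in the original gist): report the loaded model's memory use;
# get_memory_footprint() is a standard transformers utility on PreTrainedModel.
print(f"Model memory footprint: {llava_model.get_memory_footprint() / 1e9:.2f} GB")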
# Restore the vision head's out_proj from bitsandbytes' `Linear4bit` back to a plain `nn.Linear`; it is not dynamically quantizable.
if IS_4BIT_MODE:
    attention = llava_model.vision_tower.vision_model.head.attention
    attention.out_proj = torch.nn.Linear(
        attention.embed_dim,
        attention.embed_dim,
        device=device,
        dtype=MODEL_NATIVE_DTYPE,
    )
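    # Note (our reading, not stated in the original gist): the fresh nn.Linear
    # is randomly initialized, which should be harmless because LLaVA consumes
    # a hidden-state layer of the vision tower rather than its pooled head output.
    assert attention.out_proj.weight.shape == (attention.embed_dim, attention.embed_dim)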
with torch.no_grad():
    # Load image
    image = Image.open(IMAGE_PATH)
    # Build the conversation
    convo = [
        {
            "role": "system",
            "content": "You are a helpful image captioner.",
        },
        {
            "role": "user",
            "content": PROMPT,
        },
    ]
    # Format the conversation
    # WARNING: HF's handling of chats on Llava models is very fragile. This specific combination
    # of processor.apply_chat_template() and processor() works, but when using other combinations
    # always inspect the final input_ids to ensure they are correct. It is easy to end up with
    # multiple <bos> tokens, which can make the model perform poorly.
    convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    assert isinstance(convo_string, str)
    # Process the inputs
    inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to(device)
    inputs['pixel_values'] = inputs['pixel_values'].to(MODEL_NATIVE_DTYPE)
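    # Optional check (not in the original gist), following the WARNING above:
    # count <bos> tokens to catch accidental duplication by the chat template.
    if processor.tokenizer.bos_token_id is not None:
        bos_count = (inputs['input_ids'][0] == processor.tokenizer.bos_token_id).sum().item()
        if bos_count != 1:
            print(f"Warning: expected exactly one <bos> token in input_ids, found {bos_count}")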
    # Generate the caption
    generate_ids = llava_model.generate(
        **inputs,
        max_new_tokens=300,
        do_sample=True,
        suppress_tokens=None,
        use_cache=True,
        temperature=0.6,
        top_k=None,
        top_p=0.9,
    )[0]
    # Trim off the prompt
    generate_ids = generate_ids[inputs['input_ids'].shape[1]:]
    # Decode the caption
    caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    caption = caption.strip()
    print(caption)
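# Example run (assumptions, not from the original gist: the script is saved as
# caption_nf4.py and the environment has torch, transformers, bitsandbytes,
# accelerate, and pillow installed):
#   python caption_nf4.py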