import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
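
# User settings: the image to caption, the instruction prompt, the model
# repo to load, and whether to run the model in 4-bit NF4.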
IMAGE_PATH = "dog.jpg"
PROMPT = "Write a long descriptive caption for this image in a formal tone."
MODEL_NAME = "John6666/llama-joycaption-alpha-two-vqa-test-1-nf4"
IS_4BIT_MODE = True
MODEL_NATIVE_DTYPE = torch.bfloat16
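
# Fix the seed so sampled captions are reproducible across runs.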
torch.manual_seed(42)
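
# Extra kwargs for from_pretrained(). In 4-bit mode, attach a bitsandbytes
# NF4 quantization config (double quantization saves a little more memory;
# compute still happens in bfloat16).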
kwargs = {}
if IS_4BIT_MODE:
    kwargs['quantization_config'] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=MODEL_NATIVE_DTYPE,
        bnb_4bit_quant_storage=MODEL_NATIVE_DTYPE,
    )
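
# Load the processor and model onto the first CUDA device.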
device = torch.device('cuda:0')
processor = AutoProcessor.from_pretrained(MODEL_NAME)
llava_model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=MODEL_NATIVE_DTYPE, device_map=device, **kwargs)
llava_model.eval()
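
# Patch for 4-bit mode: bitsandbytes quantizes the out_proj inside the
# vision tower's attention-pooling head, which breaks that head's forward
# pass, so swap in a plain bfloat16 Linear of the same shape. LLaVA takes
# its image features from an earlier hidden state, so the fresh (untrained)
# weights here should not affect the caption.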
if IS_4BIT_MODE:
    attention = llava_model.vision_tower.vision_model.head.attention
    attention.out_proj = torch.nn.Linear(
        attention.embed_dim,
        attention.embed_dim,
        device=device,
        dtype=MODEL_NATIVE_DTYPE)
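
# Run the whole captioning pass without tracking gradients.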
with torch.no_grad():
    image = Image.open(IMAGE_PATH)
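
    # Build the conversation as role/content messages.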
    convo = [
        {
            "role": "system",
            "content": "You are a helpful image captioner.",
        },
        {
            "role": "user",
            "content": PROMPT,
        },
    ]
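
    # Render the chat with the model's template; add_generation_prompt=True
    # appends the assistant header so the model answers rather than
    # continuing the user turn.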
    convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    assert isinstance(convo_string, str)
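
    # Tokenize the prompt and preprocess the image, then move both to the GPU.
    # The processor returns float32 pixel values, so cast them to the model's
    # dtype before the forward pass.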
    inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to(device)
    inputs['pixel_values'] = inputs['pixel_values'].to(MODEL_NATIVE_DTYPE)
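
    # Sample a caption (temperature plus nucleus sampling).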
    generate_ids = llava_model.generate(
        **inputs,
        max_new_tokens=300,
        do_sample=True,
        suppress_tokens=None,
        use_cache=True,
        temperature=0.6,
        top_k=None,
        top_p=0.9,
    )[0]
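
    # Drop the prompt tokens; keep only the newly generated ones.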
    generate_ids = generate_ids[inputs['input_ids'].shape[1]:]
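
    # Decode the generated tokens and print the final caption.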
    caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    caption = caption.strip()
    print(caption)