File size: 3,238 Bytes
a4ad4d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# from:
# https://gist.github.com/maedtb/ee16101ca80638011c975ed0c0d78497
# https://github.com/fpgaminer/joycaption/issues/3#issuecomment-2619253277
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Path of the image to caption, and the instruction sent as the user turn.
IMAGE_PATH = "dog.jpg"
PROMPT = "Write a long descriptive caption for this image in a formal tone."
# Model identifier handed to the `from_pretrained()` loaders.
MODEL_NAME = "John6666/llama-joycaption-alpha-two-vqa-test-1-nf4"
# When True, a 4-bit (NF4) BitsAndBytes quantization config is passed at load time.
IS_4BIT_MODE = True
# dtype used when loading the model and as the 4-bit compute/storage dtype.
MODEL_NATIVE_DTYPE = torch.bfloat16

# Seed PyTorch's RNG so the sampled example output is less random.
torch.manual_seed(42)

# Extra keyword arguments forwarded to the model loader.
# Empty unless 4-bit mode is enabled, in which case it carries the quantization config.
kwargs = (
    {
        'quantization_config': BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type='nf4',
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=MODEL_NATIVE_DTYPE,
            bnb_4bit_quant_storage=MODEL_NATIVE_DTYPE,
        ),
    }
    if IS_4BIT_MODE
    else {}
)

# Load the JoyCaption processor and model.
# bfloat16 is the native dtype of the LLM used in JoyCaption (Llama 3.1).
# `device` is the first GPU (cuda:0); it is passed as device_map so the model is loaded there.
device = torch.device('cuda:0')
processor = AutoProcessor.from_pretrained(MODEL_NAME)
llava_model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype=MODEL_NATIVE_DTYPE,
    device_map=device,
    **kwargs,
)
# Switch the model to evaluation mode; this script only runs inference.
llava_model.eval()

# In 4-bit mode, swap the vision head's attention out_proj back to a plain `nn.Linear`
# (instead of `nn.Linear4bit`); it is not dynamically quantizable.
# NOTE(review): the replacement layer is constructed fresh here and no weights are copied
# into it, so it carries torch's default initialization. Confirm the vision head's output
# is not on the path that feeds the language model.
if IS_4BIT_MODE:
    head_attention = llava_model.vision_tower.vision_model.head.attention
    head_attention.out_proj = torch.nn.Linear(
        head_attention.embed_dim,
        head_attention.embed_dim,
        device=device,
        dtype=MODEL_NATIVE_DTYPE,
    )

# Caption IMAGE_PATH with the loaded model and print the result.
# Everything runs under no_grad because this is inference only.
with torch.no_grad():
    # Build the conversation: a fixed system prompt plus the user's captioning instruction.
    convo = [
        {
            "role": "system",
            "content": "You are a helpful image captioner.",
        },
        {
            "role": "user",
            "content": PROMPT,
        },
    ]

    # Format the conversation
    # WARNING: HF's handling of chats on Llava models is very fragile.  This specific combination of
    # processor.apply_chat_template() and processor() works, but if using other combinations always
    # inspect the final input_ids to ensure they are correct.  Oftentimes you will end up with
    # multiple <bos> tokens if not careful, which can make the model perform poorly.
    convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    assert isinstance(convo_string, str)

    # Load and preprocess the image inside a context manager so the underlying file handle
    # is closed as soon as the processor has turned the pixels into tensors.
    with Image.open(IMAGE_PATH) as image:
        # Move the inputs to the same device the model was loaded on (`device`), rather than
        # the ambient 'cuda' device, which may be a different GPU on multi-GPU hosts.
        inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to(device)

    # Cast the pixel values to the dtype the model was loaded with.
    inputs['pixel_values'] = inputs['pixel_values'].to(MODEL_NATIVE_DTYPE)

    # Generate the caption; [0] selects the only sequence in the batch.
    generate_ids = llava_model.generate(
        **inputs,
        max_new_tokens=300,
        do_sample=True,
        suppress_tokens=None,
        use_cache=True,
        temperature=0.6,
        top_k=None,
        top_p=0.9,
    )[0]

    # Trim off the prompt: keep only the tokens produced after the input sequence.
    prompt_length = inputs['input_ids'].shape[1]
    generate_ids = generate_ids[prompt_length:]

    # Decode the caption
    caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    caption = caption.strip()
    print(caption)