nroggendorff committed
Commit c1b1497 · verified · 1 Parent(s): 3661d37

Update train.py

Files changed (1): train.py (+43 -54)
train.py CHANGED
@@ -6,9 +6,7 @@ from typing import cast
 import os
 import shutil
 import multiprocessing as mp
-from torch.utils.data import DataLoader
 from PIL import Image
-from functools import partial
 
 
 def load_model(model_name, device_id=0):
@@ -33,17 +31,16 @@ def load_model(model_name, device_id=0):
     return processor, model
 
 
-def prepare_image(image):
-    if isinstance(image, Image.Image):
-        if image.mode != "RGB":
-            image = image.convert("RGB")
-        return image
-    return image
-
-
-def collate_fn(batch, processor):
-    images = [prepare_image(item["image"]) for item in batch]
-
+def caption_batch(batch, processor, model):
+    images = batch["image"]
+
+    pil_images = []
+    for image in images:
+        if isinstance(image, Image.Image):
+            if image.mode != "RGB":
+                image = image.convert("RGB")
+        pil_images.append(image)
+
     msg = [
         {
             "role": "user",
@@ -60,9 +57,36 @@ def collate_fn(batch, processor):
     text = processor.apply_chat_template(
         msg, add_generation_prompt=True, tokenize=False
     )
-    texts = [text] * len(images)
-
-    return processor(text=texts, images=images, return_tensors="pt", padding=True)
+    texts = [text] * len(pil_images)
+
+    inputs = processor(text=texts, images=pil_images, return_tensors="pt", padding=True)
+
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
+        generated = model.generate(
+            **inputs,
+            max_new_tokens=128,
+            do_sample=False,
+        )
+
+    decoded = processor.batch_decode(generated, skip_special_tokens=False)
+
+    captions = []
+    special_tokens = set(processor.tokenizer.all_special_tokens)
+    for d in decoded:
+        if "<|im_start|>assistant" in d:
+            d = d.split("<|im_start|>assistant")[-1]
+
+        for token in special_tokens:
+            d = d.replace(token, "")
+
+        d = d.strip()
+        captions.append(d)
+
+    return {
+        "text": captions,
+    }
 
 
 def process_shard(gpu_id, start, end, model_name, batch_size, input_dataset, output_file):
@@ -80,51 +104,16 @@ def process_shard(gpu_id, start, end, model_name, batch_size, input_dataset, output_file):
     else:
         shard = cast(Dataset, loaded)
 
-    shard.set_format(type="torch", columns=["image"])
-
-    dataloader = DataLoader(
-        shard,
+    print(f"[GPU {gpu_id}] Processing {len(shard)} examples...", flush=True)
+    result = shard.map(
+        lambda batch: caption_batch(batch, processor, model),
+        batched=True,
         batch_size=batch_size,
-        num_workers=4,
-        pin_memory=True,
-        collate_fn=partial(collate_fn, processor=processor),
-        prefetch_factor=2,
+        remove_columns=[col for col in shard.column_names if col != "image"],
     )
 
-    all_captions = []
-    special_tokens = set(processor.tokenizer.all_special_tokens)
-
-    print(f"[GPU {gpu_id}] Processing {len(shard)} examples...", flush=True)
-
-    for batch_idx, inputs in enumerate(dataloader):
-        inputs = {k: v.to(model.device, non_blocking=True) for k, v in inputs.items()}
-
-        with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
-            generated = model.generate(
-                **inputs,
-                max_new_tokens=128,
-                do_sample=False,
-            )
-
-        decoded = processor.batch_decode(generated, skip_special_tokens=False)
-
-        for d in decoded:
-            if "<|im_start|>assistant" in d:
-                d = d.split("<|im_start|>assistant")[-1]
-
-            for token in special_tokens:
-                d = d.replace(token, "")
-
-            d = d.strip()
-            all_captions.append(d)
-
-        if (batch_idx + 1) % 10 == 0:
-            print(f"[GPU {gpu_id}] Processed {(batch_idx + 1) * batch_size}/{len(shard)} examples", flush=True)
-
-    result_ds = Dataset.from_dict({"text": all_captions})
-
     print(f"[GPU {gpu_id}] Saving results to {output_file}...", flush=True)
-    result_ds.save_to_disk(output_file)
+    result.save_to_disk(output_file)
 
     print(f"[GPU {gpu_id}] Done!", flush=True)
    return output_file
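
A note on the shape of this refactor: the torch DataLoader/collate_fn pipeline is replaced by the built-in batched map from datasets, so with batched=True the new caption_batch receives a dict of columns (batch["image"] is a list of images) and returns a dict of lists ({"text": captions}), and each GPU worker is now self-contained. The driver that fans process_shard out over devices sits outside this diff; the sketch below is a minimal, hypothetical example of what it could look like, based only on the multiprocessing import the file keeps and the process_shard(gpu_id, start, end, model_name, batch_size, input_dataset, output_file) signature. The model name, dataset name, GPU count, row count, and output paths are placeholder assumptions.

import multiprocessing as mp

from datasets import concatenate_datasets, load_from_disk

from train import process_shard  # assumes this script lives next to train.py

if __name__ == "__main__":
    # Placeholder values -- none of these appear in the diff itself.
    model_name = "some-org/some-vlm"
    input_dataset = "some-org/some-image-dataset"
    num_gpus = 2
    batch_size = 8
    total_rows = 10_000

    # CUDA cannot be re-initialized in a forked child, so use spawn workers.
    ctx = mp.get_context("spawn")

    rows_per_gpu = total_rows // num_gpus
    jobs = []
    for gpu_id in range(num_gpus):
        start = gpu_id * rows_per_gpu
        # The last shard absorbs any remainder rows.
        end = total_rows if gpu_id == num_gpus - 1 else start + rows_per_gpu
        output_file = f"captions_shard_{gpu_id}"
        p = ctx.Process(
            target=process_shard,
            args=(gpu_id, start, end, model_name, batch_size, input_dataset, output_file),
        )
        p.start()
        jobs.append((p, output_file))

    for p, _ in jobs:
        p.join()

    # Each worker saved its shard with save_to_disk, so the pieces can be
    # reloaded and concatenated into one captioned dataset.
    merged = concatenate_datasets([load_from_disk(path) for _, path in jobs])
    merged.save_to_disk("captions_merged")

Spawn is the conservative start method here since each worker presumably initializes its own model on its own device (load_model takes a device_id, per the context lines above); how the real driver assigns devices is not shown in this commit.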