nroggendorff committed
Commit 73418cf · verified · 1 parent: 84f4c93

Update train.py

Files changed (1):
  1. train.py +45 -36
train.py CHANGED
@@ -67,9 +67,15 @@ def preprocess_example_batch(examples, text):
     }
 
 
-def run_preprocessing(input_dataset, output_dir, num_proc=32, batch_size=100):
+def run_preprocessing(input_dataset, output_dir, num_proc=32, batch_size=100, start_idx=0, end_idx=None):
     print("Loading dataset for preprocessing...")
     ds = datasets.load_dataset(input_dataset, split="train")
+
+    if end_idx is None:
+        end_idx = len(ds)
+
+    print(f"Selecting range [{start_idx}:{end_idx}]...")
+    ds = ds.select(range(start_idx, end_idx))
 
     print("Loading processor...")
     processor = AutoProcessor.from_pretrained("datalab-to/chandra")
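The new `start_idx`/`end_idx` parameters move the half-split selection into `run_preprocessing` itself. A minimal sketch of that selection logic, using a toy in-memory dataset instead of the real `none-yet/anime-captions` repo:

```python
from datasets import Dataset

# Toy stand-in for the real dataset.
ds = Dataset.from_dict({"idx": list(range(10))})

start_idx, end_idx = 3, None
if end_idx is None:  # mirrors the new default in run_preprocessing
    end_idx = len(ds)

# select() takes an iterable of row indices and returns a new Dataset.
subset = ds.select(range(start_idx, end_idx))
print(len(subset))  # 7
```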
@@ -99,7 +105,7 @@ def caption_batch(batch, processor, model):
         k: v.pin_memory().to(model.device, non_blocking=True) for k, v in inputs.items()
     }
 
-    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16): # type: ignore
+    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
         generated = model.generate(
             **inputs,
             max_new_tokens=128,
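For context, the pinned-memory transfer and bf16 autocast pattern this hunk touches looks roughly like the following. This is a sketch with a stand-in computation instead of `model.generate`, and it assumes a CUDA device is available:

```python
import torch

# Toy batch standing in for the processor's output.
inputs = {"input_ids": torch.randint(0, 100, (1, 16))}

# pin_memory() stages tensors in page-locked host RAM so the
# non_blocking copy to the GPU can overlap with other work.
inputs = {k: v.pin_memory().to("cuda", non_blocking=True) for k, v in inputs.items()}

# no_grad() skips autograd bookkeeping; autocast runs eligible ops
# in bfloat16 for the duration of the block.
with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
    out = inputs["input_ids"].float().mean()  # stand-in for model.generate(**inputs)
print(out.item())
```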
@@ -160,36 +166,34 @@ def process_shard(
 def main():
     mp.set_start_method("spawn", force=True)
 
+    init_stage = os.environ.get("INIT", "0")
+
     input_dataset = "none-yet/anime-captions"
-    preprocessed_dataset = "temp_preprocessed"
     output_dataset = "nroggendorff/anime-captions"
     model_name = "datalab-to/chandra"
     batch_size = 20
 
-    init_flag = os.environ.get("INIT", "0")
-    is_first_run = init_flag == "0"
-    is_second_run = init_flag == "1"
+    print(f"Running stage INIT={init_stage}")
+
+    full_ds = datasets.load_dataset(input_dataset, split="train")
+    total_dataset_size = len(full_ds)
+    midpoint = total_dataset_size // 2
+
+    if init_stage == "0":
+        print(f"Stage 0: Processing first half [0:{midpoint}]")
+        preprocessed_dataset = "temp_preprocessed_0"
+        start_idx = 0
+        end_idx = midpoint
+        final_output = f"{output_dataset}_part0"
+    else:
+        print(f"Stage 1: Processing second half [{midpoint}:{total_dataset_size}]")
+        preprocessed_dataset = "temp_preprocessed_1"
+        start_idx = midpoint
+        end_idx = total_dataset_size
+        final_output = input_dataset
 
     if not os.path.exists(preprocessed_dataset):
-        print(f"[{'First' if is_first_run else 'Second'} Run] Running preprocessing...")
-        ds_full = datasets.load_dataset(input_dataset, split="train")
-        total_size = len(ds_full)
-        midpoint = total_size // 2
-
-        if is_first_run:
-            ds_to_process = ds_full.select(range(0, midpoint))
-        else:
-            ds_to_process = ds_full.select(range(midpoint, total_size))
-
-        print(
-            f"[{'First' if is_first_run else 'Second'} Run] Saving selected shard to disk..."
-        )
-        ds_to_process.save_to_disk("temp_input_shard")
-
-        run_preprocessing("temp_input_shard", preprocessed_dataset)
-
-        # Clean up temp input shard
-        shutil.rmtree("temp_input_shard")
+        run_preprocessing(input_dataset, preprocessed_dataset, start_idx=start_idx, end_idx=end_idx)
 
     print("Loading preprocessed dataset...")
     ds = datasets.load_from_disk(preprocessed_dataset)
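The rewritten `main()` now derives every per-stage value (slice bounds, preprocessing directory, push target) from the single `INIT` environment variable, which also removes the `temp_input_shard` save/load round trip. A compact sketch of that branch logic, with names mirroring the diff and an illustrative size:

```python
import os

def stage_config(init_stage, total, output_dataset, input_dataset):
    """Return (start_idx, end_idx, preprocessed_dir, final_output) for a stage."""
    midpoint = total // 2
    if init_stage == "0":
        return 0, midpoint, "temp_preprocessed_0", f"{output_dataset}_part0"
    return midpoint, total, "temp_preprocessed_1", input_dataset

# Hypothetical repo names, just to show the two stages side by side.
print(stage_config("0", 1000, "user/out", "user/in"))  # (0, 500, 'temp_preprocessed_0', 'user/out_part0')
print(stage_config("1", 1000, "user/out", "user/in"))  # (500, 1000, 'temp_preprocessed_1', 'user/in')
```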
@@ -207,7 +211,7 @@ def main():
     for i in range(num_gpus):
         start = i * shard_size
         end = start + shard_size if i < num_gpus - 1 else total_size
-        output_file = f"temp_shard_{i}"
+        output_file = f"temp_shard_{init_stage}_{i}"
        temp_files.append(output_file)
 
         p = mp.Process(
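Prefixing the per-GPU temp shard names with the stage keeps stage-0 and stage-1 worker outputs from colliding when both runs share a disk. Illustrative values:

```python
init_stage, num_gpus = "0", 4
temp_files = [f"temp_shard_{init_stage}_{i}" for i in range(num_gpus)]
print(temp_files)  # ['temp_shard_0_0', 'temp_shard_0_1', 'temp_shard_0_2', 'temp_shard_0_3']
```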
@@ -242,16 +246,21 @@ def main():
     shards = [cast(Dataset, datasets.load_from_disk(f)) for f in temp_files]
     final_ds = datasets.concatenate_datasets(shards)
 
-    if is_first_run:
-        print("First run: pushing first half to hub...")
-        final_ds.push_to_hub(output_dataset, create_pr=False)
+    print(f"Final dataset size: {len(final_ds)}")
+
+    if init_stage == "0":
+        print(f"Pushing first half to {final_output}...")
+        final_ds.push_to_hub(final_output, create_pr=False)
     else:
-        print("Second run: loading first half and merging...")
-        first_half_ds = datasets.load_dataset(output_dataset, split="train")
-        merged_ds = datasets.concatenate_datasets([first_half_ds, final_ds])
-        print(f"Final merged dataset size: {len(merged_ds)}")
-        print("Pushing full dataset with create_pr=True...")
-        merged_ds.push_to_hub(output_dataset, create_pr=True)
+        print("Loading first half from hub...")
+        first_half = datasets.load_dataset(f"{output_dataset}_part0", split="train")
+
+        print("Concatenating both halves...")
+        complete_ds = datasets.concatenate_datasets([first_half, final_ds])
+
+        print(f"Complete dataset size: {len(complete_ds)}")
+        print(f"Pushing complete dataset to {final_output} with PR...")
+        complete_ds.push_to_hub(final_output, create_pr=True)
 
     print("Cleaning up temporary files...")
     for f in temp_files:
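The stage-1 branch now rebuilds the full dataset by concatenating the previously pushed `_part0` half with this run's captions before opening a hub PR. A self-contained sketch with toy data; the `push_to_hub` call is commented out since it needs hub credentials, and the repo name is illustrative:

```python
from datasets import Dataset, concatenate_datasets

first_half = Dataset.from_dict({"idx": [0, 1, 2]})
second_half = Dataset.from_dict({"idx": [3, 4, 5]})

# Order matters: first half, then the freshly captioned second half.
complete_ds = concatenate_datasets([first_half, second_half])
print(len(complete_ds))  # 6

# create_pr=True opens a pull request on the target repo
# instead of committing to main directly.
# complete_ds.push_to_hub("user/repo", create_pr=True)
```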
@@ -264,4 +273,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()
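Taken together, the script is now meant to be run twice, once per half. A hypothetical driver showing the intended order (the real workflow presumably sets `INIT` in the job environment):

```python
import os
import subprocess

# Stage 0 captions and pushes the first half; stage 1 captions the
# second half, merges both, and opens a PR on the hub.
for stage in ("0", "1"):
    subprocess.run(
        ["python", "train.py"],
        env={**os.environ, "INIT": stage},
        check=True,
    )
```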