Spaces:
Paused
Paused
Update train.py
Browse files
train.py
CHANGED
|
@@ -67,9 +67,15 @@ def preprocess_example_batch(examples, text):
|
|
| 67 |
}
|
| 68 |
|
| 69 |
|
| 70 |
-
def run_preprocessing(input_dataset, output_dir, num_proc=32, batch_size=100):
|
| 71 |
print("Loading dataset for preprocessing...")
|
| 72 |
ds = datasets.load_dataset(input_dataset, split="train")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
print("Loading processor...")
|
| 75 |
processor = AutoProcessor.from_pretrained("datalab-to/chandra")
|
|
@@ -99,7 +105,7 @@ def caption_batch(batch, processor, model):
|
|
| 99 |
k: v.pin_memory().to(model.device, non_blocking=True) for k, v in inputs.items()
|
| 100 |
}
|
| 101 |
|
| 102 |
-
with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
|
| 103 |
generated = model.generate(
|
| 104 |
**inputs,
|
| 105 |
max_new_tokens=128,
|
|
@@ -160,36 +166,34 @@ def process_shard(
|
|
| 160 |
def main():
|
| 161 |
mp.set_start_method("spawn", force=True)
|
| 162 |
|
|
|
|
|
|
|
| 163 |
input_dataset = "none-yet/anime-captions"
|
| 164 |
-
preprocessed_dataset = "temp_preprocessed"
|
| 165 |
output_dataset = "nroggendorff/anime-captions"
|
| 166 |
model_name = "datalab-to/chandra"
|
| 167 |
batch_size = 20
|
| 168 |
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
if not os.path.exists(preprocessed_dataset):
|
| 174 |
-
|
| 175 |
-
ds_full = datasets.load_dataset(input_dataset, split="train")
|
| 176 |
-
total_size = len(ds_full)
|
| 177 |
-
midpoint = total_size // 2
|
| 178 |
-
|
| 179 |
-
if is_first_run:
|
| 180 |
-
ds_to_process = ds_full.select(range(0, midpoint))
|
| 181 |
-
else:
|
| 182 |
-
ds_to_process = ds_full.select(range(midpoint, total_size))
|
| 183 |
-
|
| 184 |
-
print(
|
| 185 |
-
f"[{'First' if is_first_run else 'Second'} Run] Saving selected shard to disk..."
|
| 186 |
-
)
|
| 187 |
-
ds_to_process.save_to_disk("temp_input_shard")
|
| 188 |
-
|
| 189 |
-
run_preprocessing("temp_input_shard", preprocessed_dataset)
|
| 190 |
-
|
| 191 |
-
# Clean up temp input shard
|
| 192 |
-
shutil.rmtree("temp_input_shard")
|
| 193 |
|
| 194 |
print("Loading preprocessed dataset...")
|
| 195 |
ds = datasets.load_from_disk(preprocessed_dataset)
|
|
@@ -207,7 +211,7 @@ def main():
|
|
| 207 |
for i in range(num_gpus):
|
| 208 |
start = i * shard_size
|
| 209 |
end = start + shard_size if i < num_gpus - 1 else total_size
|
| 210 |
-
output_file = f"temp_shard_{i}"
|
| 211 |
temp_files.append(output_file)
|
| 212 |
|
| 213 |
p = mp.Process(
|
|
@@ -242,16 +246,21 @@ def main():
|
|
| 242 |
shards = [cast(Dataset, datasets.load_from_disk(f)) for f in temp_files]
|
| 243 |
final_ds = datasets.concatenate_datasets(shards)
|
| 244 |
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
|
|
|
|
|
|
| 248 |
else:
|
| 249 |
-
print("
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
print(
|
| 253 |
-
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
print("Cleaning up temporary files...")
|
| 257 |
for f in temp_files:
|
|
@@ -264,4 +273,4 @@ def main():
|
|
| 264 |
|
| 265 |
|
| 266 |
if __name__ == "__main__":
|
| 267 |
-
main()
|
|
|
|
| 67 |
}
|
| 68 |
|
| 69 |
|
| 70 |
+
def run_preprocessing(input_dataset, output_dir, num_proc=32, batch_size=100, start_idx=0, end_idx=None):
|
| 71 |
print("Loading dataset for preprocessing...")
|
| 72 |
ds = datasets.load_dataset(input_dataset, split="train")
|
| 73 |
+
|
| 74 |
+
if end_idx is None:
|
| 75 |
+
end_idx = len(ds)
|
| 76 |
+
|
| 77 |
+
print(f"Selecting range [{start_idx}:{end_idx}]...")
|
| 78 |
+
ds = ds.select(range(start_idx, end_idx))
|
| 79 |
|
| 80 |
print("Loading processor...")
|
| 81 |
processor = AutoProcessor.from_pretrained("datalab-to/chandra")
|
|
|
|
| 105 |
k: v.pin_memory().to(model.device, non_blocking=True) for k, v in inputs.items()
|
| 106 |
}
|
| 107 |
|
| 108 |
+
with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
|
| 109 |
generated = model.generate(
|
| 110 |
**inputs,
|
| 111 |
max_new_tokens=128,
|
|
|
|
| 166 |
def main():
|
| 167 |
mp.set_start_method("spawn", force=True)
|
| 168 |
|
| 169 |
+
init_stage = os.environ.get("INIT", "0")
|
| 170 |
+
|
| 171 |
input_dataset = "none-yet/anime-captions"
|
|
|
|
| 172 |
output_dataset = "nroggendorff/anime-captions"
|
| 173 |
model_name = "datalab-to/chandra"
|
| 174 |
batch_size = 20
|
| 175 |
|
| 176 |
+
print(f"Running stage INIT={init_stage}")
|
| 177 |
+
|
| 178 |
+
full_ds = datasets.load_dataset(input_dataset, split="train")
|
| 179 |
+
total_dataset_size = len(full_ds)
|
| 180 |
+
midpoint = total_dataset_size // 2
|
| 181 |
+
|
| 182 |
+
if init_stage == "0":
|
| 183 |
+
print(f"Stage 0: Processing first half [0:{midpoint}]")
|
| 184 |
+
preprocessed_dataset = "temp_preprocessed_0"
|
| 185 |
+
start_idx = 0
|
| 186 |
+
end_idx = midpoint
|
| 187 |
+
final_output = f"{output_dataset}_part0"
|
| 188 |
+
else:
|
| 189 |
+
print(f"Stage 1: Processing second half [{midpoint}:{total_dataset_size}]")
|
| 190 |
+
preprocessed_dataset = "temp_preprocessed_1"
|
| 191 |
+
start_idx = midpoint
|
| 192 |
+
end_idx = total_dataset_size
|
| 193 |
+
final_output = input_dataset
|
| 194 |
|
| 195 |
if not os.path.exists(preprocessed_dataset):
|
| 196 |
+
run_preprocessing(input_dataset, preprocessed_dataset, start_idx=start_idx, end_idx=end_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
print("Loading preprocessed dataset...")
|
| 199 |
ds = datasets.load_from_disk(preprocessed_dataset)
|
|
|
|
| 211 |
for i in range(num_gpus):
|
| 212 |
start = i * shard_size
|
| 213 |
end = start + shard_size if i < num_gpus - 1 else total_size
|
| 214 |
+
output_file = f"temp_shard_{init_stage}_{i}"
|
| 215 |
temp_files.append(output_file)
|
| 216 |
|
| 217 |
p = mp.Process(
|
|
|
|
| 246 |
shards = [cast(Dataset, datasets.load_from_disk(f)) for f in temp_files]
|
| 247 |
final_ds = datasets.concatenate_datasets(shards)
|
| 248 |
|
| 249 |
+
print(f"Final dataset size: {len(final_ds)}")
|
| 250 |
+
|
| 251 |
+
if init_stage == "0":
|
| 252 |
+
print(f"Pushing first half to {final_output}...")
|
| 253 |
+
final_ds.push_to_hub(final_output, create_pr=False)
|
| 254 |
else:
|
| 255 |
+
print("Loading first half from hub...")
|
| 256 |
+
first_half = datasets.load_dataset(f"{output_dataset}_part0", split="train")
|
| 257 |
+
|
| 258 |
+
print("Concatenating both halves...")
|
| 259 |
+
complete_ds = datasets.concatenate_datasets([first_half, final_ds])
|
| 260 |
+
|
| 261 |
+
print(f"Complete dataset size: {len(complete_ds)}")
|
| 262 |
+
print(f"Pushing complete dataset to {final_output} with PR...")
|
| 263 |
+
complete_ds.push_to_hub(final_output, create_pr=True)
|
| 264 |
|
| 265 |
print("Cleaning up temporary files...")
|
| 266 |
for f in temp_files:
|
|
|
|
| 273 |
|
| 274 |
|
| 275 |
if __name__ == "__main__":
|
| 276 |
+
main()
|