import multiprocessing as mp
import os
import shutil
from typing import cast

import datasets
import torch
from datasets import Dataset
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig


def load_model(model_name, device_id=0):
    """Load the captioning model in 4-bit NF4 quantization on a single GPU."""
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    processor = AutoProcessor.from_pretrained(model_name)
    # Left-padding so generation continues from the end of each prompt in a batch.
    processor.tokenizer.padding_side = "left"
    model = AutoModelForImageTextToText.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        dtype=torch.bfloat16,
        device_map={"": device_id},
        attn_implementation="flash_attention_2",
    )
    return processor, model


def get_template(processor):
    """Build the chat-formatted captioning prompt once; it is identical for every image."""
    msg = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {
                    "type": "text",
                    "text": (
                        "Describe the image concisely, and skip mentioning "
                        "that it's illustrated or from anime."
                    ),
                },
            ],
        }
    ]
    return processor.apply_chat_template(
        msg, add_generation_prompt=True, tokenize=False
    )


def preprocess_example_batch(examples, text):
    """Convert every image to RGB and attach the shared prompt text."""
    processed_images = []
    for image in examples["image"]:
        if not isinstance(image, Image.Image):
            raise ValueError(f"Expected a PIL Image, got {type(image)}")
        if image.mode != "RGB":
            image = image.convert("RGB")
        processed_images.append(image)
    return {
        "image": processed_images,
        "text": [text] * len(processed_images),
    }


def run_preprocessing(
    input_dataset,
    output_dir,
    model_name,
    num_proc=32,
    batch_size=100,
    start_idx=0,
    end_idx=None,
):
    """Materialize the selected slice of the dataset to disk with RGB images
    and the prompt column, so the GPU workers can read it directly."""
    print("Loading dataset for preprocessing...")
    ds = datasets.load_dataset(input_dataset, split="train")
    if end_idx is None:
        end_idx = len(ds)
    print(f"Selecting range [{start_idx}:{end_idx}]...")
    ds = ds.select(range(start_idx, end_idx))

    print("Loading processor...")
    processor = AutoProcessor.from_pretrained(model_name)
    text = get_template(processor)

    print("Running preprocessing...")
    processed_ds = ds.map(
        lambda ex: preprocess_example_batch(ex, text),
        remove_columns=[
            col for col in ds.column_names if col not in ["image", "text"]
        ],
        num_proc=num_proc,
        batched=True,
        batch_size=batch_size,
    )
    print(f"Saving preprocessed dataset to {output_dir}...")
    processed_ds.save_to_disk(output_dir)
    print("Preprocessing done.")


def caption_batch(batch, processor, model):
    """Generate captions for a batch of (image, prompt) pairs."""
    images = batch["image"]
    texts = batch["text"]
    inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
    # Pin host memory so the async copy to the GPU can overlap with compute.
    inputs = {
        k: v.pin_memory().to(model.device, non_blocking=True)
        for k, v in inputs.items()
    }
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        generated = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=False,
        )
    # Keep special tokens in the decode so the assistant marker can be located,
    # then strip the echoed prompt and all special tokens manually.
    decoded = processor.batch_decode(generated, skip_special_tokens=False)
    captions = []
    special_tokens = set(processor.tokenizer.all_special_tokens)
    for d in decoded:
        if "<|im_start|>assistant" in d:
            d = d.split("<|im_start|>assistant")[-1]
        for token in special_tokens:
            d = d.replace(token, "")
        captions.append(d.strip())
    return {"text": captions}


def process_shard(gpu_id, start, end, model_name, batch_size, input_dataset, output_file):
    """Worker process: caption one contiguous shard of the dataset on one GPU."""
    try:
        torch.cuda.set_device(gpu_id)
        print(f"[GPU {gpu_id}] Loading model...", flush=True)
        processor, model = load_model(model_name, gpu_id)

        print(f"[GPU {gpu_id}] Loading data shard [{start}:{end}]...", flush=True)
        loaded = datasets.load_from_disk(input_dataset).select(range(start, end))
        shard = cast(Dataset, loaded)

        print(f"[GPU {gpu_id}] Processing {len(shard)} examples...", flush=True)
        # remove_columns=["text"] drops the prompt column; the captions returned
        # by caption_batch take its place under the same name.
        result = shard.map(
            lambda batch: caption_batch(batch, processor, model),
            batched=True,
            batch_size=batch_size,
            remove_columns=["text"],
        )
        print(f"[GPU {gpu_id}] Saving results to {output_file}...", flush=True)
        result.save_to_disk(output_file)
        print(f"[GPU {gpu_id}] Done!", flush=True)
        return output_file
    except Exception as e:
        print(f"[GPU {gpu_id}] Error: {e}", flush=True)
        raise


def main():
    # CUDA state does not survive fork; worker processes must be spawned.
    mp.set_start_method("spawn", force=True)

    init_stage = os.environ.get("INIT", "0")
    input_dataset = "none-yet/anime-captions"
    output_dataset = "nroggendorff/anime-captions"
    model_name = "datalab-to/chandra"
    batch_size = 20

    print(f"Running stage INIT={init_stage}")

    # Load the full dataset once, only to split it into two halves by index.
    full_ds = datasets.load_dataset(input_dataset, split="train")
    total_dataset_size = len(full_ds)
    midpoint = total_dataset_size // 2

    if init_stage == "0":
        print(f"Stage 0: Processing first half [0:{midpoint}]")
        preprocessed_dataset = "temp_preprocessed_0"
        start_idx = 0
        end_idx = midpoint
        final_output = f"{output_dataset}_part0"
    else:
        print(f"Stage 1: Processing second half [{midpoint}:{total_dataset_size}]")
        preprocessed_dataset = "temp_preprocessed_1"
        start_idx = midpoint
        end_idx = total_dataset_size
        # Stage 1 pushes the merged result back to the source repo as a PR.
        final_output = input_dataset

    if not os.path.exists(preprocessed_dataset):
        run_preprocessing(
            input_dataset,
            preprocessed_dataset,
            model_name,
            start_idx=start_idx,
            end_idx=end_idx,
        )

    print("Loading preprocessed dataset...")
    ds = datasets.load_from_disk(preprocessed_dataset)

    num_gpus = torch.cuda.device_count()
    total_size = len(ds)
    shard_size = total_size // num_gpus
    print(f"Dataset size: {total_size}")
    print(f"Using {num_gpus} GPUs")
    print(f"Shard size: {shard_size}")

    # One worker per GPU; the last shard absorbs the remainder.
    processes = []
    temp_files = []
    for i in range(num_gpus):
        start = i * shard_size
        end = start + shard_size if i < num_gpus - 1 else total_size
        output_file = f"temp_shard_{init_stage}_{i}"
        temp_files.append(output_file)
        p = mp.Process(
            target=process_shard,
            args=(
                i,
                start,
                end,
                model_name,
                batch_size,
                preprocessed_dataset,
                output_file,
            ),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
        if p.exitcode != 0:
            print(f"\nProcess failed with exit code {p.exitcode}", flush=True)
            print("Terminating all processes...", flush=True)
            for proc in processes:
                if proc.is_alive():
                    proc.terminate()
            for proc in processes:
                proc.join()
            raise RuntimeError("At least one process failed")

    print("\nAll processes completed. Loading and concatenating results...")
    shards = [cast(Dataset, datasets.load_from_disk(f)) for f in temp_files]
    final_ds = datasets.concatenate_datasets(shards)
    print(f"Final dataset size: {len(final_ds)}")

    if init_stage == "0":
        print(f"Pushing first half to {final_output}...")
        final_ds.push_to_hub(final_output, create_pr=False)
    else:
        print("Loading first half from hub...")
        first_half = datasets.load_dataset(f"{output_dataset}_part0", split="train")
        print("Concatenating both halves...")
        complete_ds = datasets.concatenate_datasets([first_half, final_ds])
        print(f"Complete dataset size: {len(complete_ds)}")
        print(f"Pushing complete dataset to {final_output} with PR...")
        complete_ds.push_to_hub(final_output, create_pr=True)

    print("Cleaning up temporary files...")
    for f in temp_files:
        if os.path.exists(f):
            shutil.rmtree(f)
    if os.path.exists(preprocessed_dataset):
        shutil.rmtree(preprocessed_dataset)
    print("Done!")


if __name__ == "__main__":
    main()
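
# A minimal sketch of the assumed two-stage invocation. The filename
# "caption.py" is hypothetical; the INIT environment variable and the target
# repos come from main() above:
#
#   INIT=0 python caption.py   # caption the first half, push it to
#                              # nroggendorff/anime-captions_part0
#   INIT=1 python caption.py   # caption the second half, merge both halves,
#                              # and open a PR against none-yet/anime-captions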