|
|
from pathlib import Path |
|
|
import os |
|
|
import sys |
|
|
from safetensors.torch import load_file, save_file |
|
|
import torch |
|
|
|
|
|
# Key prefix to be rewritten wherever a tensor name starts with it.
PREFIX_OLD = "lm_head."

# Replacement prefix: the same prefix nested under "language_model."
# (i.e. "language_model.lm_head.").
PREFIX_NEW = "language_model." + PREFIX_OLD
|
|
|
|
|
|
|
|
def rename_keys(
    tensor_dict: dict,
    old_prefix: "str | None" = None,
    new_prefix: "str | None" = None,
) -> dict:
    """Return a new dict whose keys starting with *old_prefix* are re-prefixed.

    Keys that start with ``old_prefix`` (default: ``PREFIX_OLD``) have that
    prefix replaced by ``new_prefix`` (default: ``PREFIX_NEW``); every other
    key is copied unchanged. Values are passed through untouched, insertion
    order is preserved, and the input dict is not modified.

    Args:
        tensor_dict: Mapping of names to values (tensors, or any object).
        old_prefix: Prefix to replace; ``None`` means ``PREFIX_OLD``.
        new_prefix: Replacement prefix; ``None`` means ``PREFIX_NEW``.

    Returns:
        A new dict with the renamed keys.

    Raises:
        ValueError: If two input keys map to the same output key (e.g. both
            ``lm_head.weight`` and ``language_model.lm_head.weight`` are
            present). Silently keeping one of them would drop a tensor.
    """
    old = PREFIX_OLD if old_prefix is None else old_prefix
    new = PREFIX_NEW if new_prefix is None else new_prefix

    out = {}
    for name, tensor in tensor_dict.items():
        new_name = new + name[len(old):] if name.startswith(old) else name
        # Input keys are unique, so a clash can only come from a renamed key
        # landing on a key that was already present under the new name.
        if new_name in out:
            raise ValueError(
                f"Key {name!r} maps to {new_name!r}, which is already "
                "produced by another key"
            )
        out[new_name] = tensor
    return out
|
|
|
|
|
|
|
|
def process_file(path: Path) -> None:
    """Rewrite one safetensors shard in place with its keys renamed.

    The shard is loaded onto the CPU and its keys are passed through
    ``rename_keys``. If no key changes, the file is left untouched.
    Otherwise the renamed tensors are written to a sibling
    ``*.safetensors.tmp`` file, which then replaces the original via
    ``os.replace``.

    The header's ``__metadata__`` dict is carried over. ``load_file`` returns
    tensors only and would otherwise drop it. Downstream loaders may inspect
    it; Hugging Face transformers, for example, checks its ``format`` entry.

    Args:
        path: Path to the ``.safetensors`` file to rewrite.
    """
    # Same package the file already imports load_file/save_file from.
    from safetensors import safe_open

    print(f"Processing {path}")

    data = load_file(str(path), device="cpu")
    renamed = rename_keys(data)

    if renamed.keys() == data.keys():
        print("  No keys needed renaming. Skipping.")
        return

    # Fetch header metadata separately; the handle is closed before the
    # original file is replaced below.
    with safe_open(str(path), framework="pt") as handle:
        metadata = handle.metadata()

    tmp_path = path.with_suffix(".safetensors.tmp")
    try:
        save_file(renamed, str(tmp_path), metadata=metadata)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a partially written temp file next to the shard.
        if tmp_path.exists():
            tmp_path.unlink()
        raise

    print("  Updated.")
|
|
|
|
|
|
|
|
def _update_index(index_path: Path) -> None:
    """Rename the ``weight_map`` keys of a shard index with ``rename_keys``."""
    import json

    if not index_path.is_file():
        return

    print(f"Processing {index_path}")
    with open(index_path, encoding="utf-8") as fh:
        index = json.load(fh)

    weight_map = index.get("weight_map") if isinstance(index, dict) else None
    if not isinstance(weight_map, dict):
        print("  No weight_map found. Skipping.")
        return

    renamed = rename_keys(weight_map)
    if renamed.keys() == weight_map.keys():
        print("  No keys needed renaming. Skipping.")
        return

    index["weight_map"] = renamed
    # Write to a temp file and swap it in, mirroring how shards are replaced.
    tmp_path = index_path.with_name(index_path.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as fh:
            json.dump(index, fh, indent=2, sort_keys=True)
            fh.write("\n")
        os.replace(tmp_path, index_path)
    except BaseException:
        if tmp_path.exists():
            tmp_path.unlink()
        raise

    print("  Updated.")


def main() -> None:
    """Rename keys in every ``model-*.safetensors`` shard in the CWD.

    Exits with status 1 if no shard file is found. After the shards are
    processed, ``model.safetensors.index.json`` (if present) gets the same
    renaming applied to its ``weight_map`` keys. Without this, the index
    would keep listing tensor names that no longer exist in the shards.
    """
    files = sorted(Path(".").glob("model-*.safetensors"))

    if not files:
        print("No model-*.safetensors files found.", file=sys.stderr)
        sys.exit(1)

    for f in files:
        process_file(f)

    _update_index(Path("model.safetensors.index.json"))
|
|
|
|
|
|
|
|
# Entry point: run main() only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|