| """VRAM Calculator for HuggingFace Models""" | |
| from __future__ import annotations | |
| import gradio as gr | |
| from huggingface_hub import HfApi, hf_hub_download | |
| import json | |
| from functools import lru_cache | |
| api = HfApi() | |
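
# Each entry: (memory in GB, approximate cloud rate in $/hr; 0 means no hourly rate is listed).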
GPU_SPECS = {
    "RTX 3080": (10, 0),
    "RTX 3090": (24, 0),
    "RTX 4080": (16, 0),
    "RTX 4090": (24, 0),
    "RTX 5090": (32, 0),
    "M2 Ultra": (192, 0),
    "M3 Max": (128, 0),
    "M4 Max": (128, 0),
    "RTX A6000": (48, 0),
    "L40S": (48, 1.00),
    "A10G": (24, 1.00),
    "L4": (24, 0.70),
    "A100 40GB": (40, 3.00),
    "A100 80GB": (80, 5.00),
    "H100 80GB": (80, 8.00),
}
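
# Bytes per parameter, keyed by both safetensors and PyTorch dtype names.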
DTYPE_BYTES = {
    "F32": 4, "float32": 4,
    "F16": 2, "float16": 2,
    "BF16": 2, "bfloat16": 2,
    "I8": 1, "int8": 1,
    "U8": 1, "uint8": 1,
}
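
# Inference overhead multipliers applied on top of weights + KV cache.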
FRAMEWORKS = {
    "None (PyTorch)": 1.20,
    "vLLM": 1.10,
    "TGI": 1.15,
    "llama.cpp": 1.05,
    "Ollama": 1.08,
}


def bytes_to_gb(b):
    return b / (1024 ** 3)
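

# Hub helpers: failures fall back to None / {} so the UI reports an error instead of crashing.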
@lru_cache(maxsize=64)
def fetch_model_info(model_id):
    """Fetch model metadata (with file metadata) from the Hub, caching repeated lookups."""
    try:
        return api.model_info(model_id, files_metadata=True)
    except Exception:
        return None


@lru_cache(maxsize=64)
def fetch_config(model_id):
    """Download and parse the model's config.json, caching repeated lookups."""
    try:
        path = hf_hub_download(model_id, "config.json")
        with open(path) as f:
            return json.load(f)
    except Exception:
        return {}


def get_params(info):
    """Return (parameter count, dominant dtype) from the model's safetensors metadata."""
    if info and hasattr(info, "safetensors") and info.safetensors:
        params = info.safetensors.total
        dtypes = info.safetensors.parameters
        if dtypes:
            dtype = max(dtypes, key=dtypes.get)
            return params, dtype
    return 0, "F16"
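

# Rough rules of thumb used below (all figures are estimates):
#   weights    = params * bytes_per_param
#   KV cache   = 2 (K and V) * layers * batch * context * kv_heads * head_dim * bytes_per_param
#   training   = weights + gradients + optimizer states (Adam: ~8 bytes/param) + activations
#   LoRA/QLoRA = frozen base (full or 4-bit precision) + small adapters + reduced activations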
def calculate(model_id, context, batch, mode, framework, num_gpus, lora_rank):
    """Main calculation function"""
    try:
        if not model_id or not model_id.strip():
            return "Enter a model ID (e.g., meta-llama/Llama-3.1-8B)"
        model_id = model_id.strip()
        if "/" not in model_id:
            return "Model ID format: organization/model-name"
        info = fetch_model_info(model_id)
        if not info:
            return "Could not fetch model: " + model_id
        config = fetch_config(model_id)
        params, dtype = get_params(info)
        if params == 0:
            return "Could not read parameters (model may use .bin format)"
        dtype_bytes = DTYPE_BYTES.get(dtype, 2)
        params_b = params / 1e9
        weights_gb = bytes_to_gb(params * dtype_bytes)
        layers = config.get("num_hidden_layers", config.get("n_layer", 32))
        kv_heads = config.get("num_key_value_heads", config.get("num_attention_heads", 32))
        # Use an explicit head_dim if the config provides one; otherwise derive it
        # from hidden_size / num_attention_heads.
        head_dim = config.get("head_dim")
        if not head_dim:
            hidden = config.get("hidden_size", 4096)
            heads = config.get("num_attention_heads", 32)
            head_dim = hidden // heads if heads else 128
        # KV cache: K and V tensors per layer, per position, per KV head.
        kv_bytes = 2 * layers * batch * context * kv_heads * head_dim * dtype_bytes
        kv_gb = bytes_to_gb(kv_bytes)
        out = []
        out.append(f"## {model_id}")
        out.append(f"**{round(params_b, 1)}B parameters** | {dtype} | {layers} layers")
        out.append("")
        if mode == "Training (Full)":
            grad_gb = weights_gb  # gradients stored at the same precision as the weights
            opt_gb = bytes_to_gb(params * 8)  # Adam: two fp32 moment tensors per parameter
            act_gb = weights_gb * 2 * batch  # coarse activation estimate
            total = weights_gb + grad_gb + opt_gb + act_gb
            out.append("### Training Memory")
            out.append(f"- Weights: {round(weights_gb, 1)} GB")
            out.append(f"- Gradients: {round(grad_gb, 1)} GB")
            out.append(f"- Optimizer: {round(opt_gb, 1)} GB")
            out.append(f"- Activations: {round(act_gb, 1)} GB")
        elif mode == "LoRA":
            base = weights_gb
            lora_params = int(params * lora_rank * 0.0001)  # rough heuristic for adapter size
            lora_gb = bytes_to_gb(lora_params * dtype_bytes)
            act_gb = base * 0.3
            total = base + lora_gb + act_gb
            out.append("### LoRA Memory")
            out.append(f"- Base (frozen): {round(base, 1)} GB")
            out.append(f"- LoRA adapters: {round(lora_gb, 2)} GB")
            out.append(f"- Activations: {round(act_gb, 1)} GB")
        elif mode == "QLoRA":
            base = bytes_to_gb(params * 0.5)  # 4-bit base weights: ~0.5 bytes per parameter
            lora_params = int(params * lora_rank * 0.0001)
            lora_gb = bytes_to_gb(lora_params * dtype_bytes)
            act_gb = base * 0.3
            total = base + lora_gb + act_gb
            out.append("### QLoRA Memory")
            out.append(f"- Base (4-bit): {round(base, 1)} GB")
            out.append(f"- LoRA adapters: {round(lora_gb, 2)} GB")
            out.append(f"- Activations: {round(act_gb, 1)} GB")
        else:
            overhead = FRAMEWORKS.get(framework, 1.15)
            extra = (weights_gb + kv_gb) * (overhead - 1)
            total = weights_gb + kv_gb + extra
            out.append("### Inference Memory")
            out.append(f"- Weights: {round(weights_gb, 1)} GB")
            out.append(f"- KV Cache: {round(kv_gb, 1)} GB")
            out.append(f"- Overhead ({framework}): {round(extra, 1)} GB")
        if num_gpus > 1:
            per_gpu = total / num_gpus * 1.05  # ~5% extra per GPU for parallelism overhead
            out.append("")
            out.append(f"**Multi-GPU ({num_gpus}x):** {round(per_gpu, 1)} GB/GPU")
            effective = per_gpu
        else:
            effective = total
        out.append("")
        out.append(f"## Total: {round(total, 1)} GB")
| out.append("") | |
| out.append("### GPU Options") | |
| out.append("| GPU | VRAM | Fits | Headroom |") | |
| out.append("|-----|------|------|----------|") | |
| for gpu, (vram, cost) in GPU_SPECS.items(): | |
| fits = "Yes" if vram >= effective else "No" | |
| hr = vram - effective | |
| sign = "+" if hr >= 0 else "" | |
| out.append("| " + gpu + " | " + str(vram) + "GB | " + fits + " | " + sign + str(round(hr, 1)) + "GB |") | |
| if effective > 24: | |
| out.append("") | |
| out.append("### Quantization to fit 24GB") | |
| out.append("| Method | Size |") | |
| out.append("|--------|------|") | |
| for name, mult in [("INT8", 1.0), ("4-bit", 0.5), ("3-bit", 0.375)]: | |
| size = bytes_to_gb(params * mult) * 1.1 | |
| out.append("| " + name + " | " + str(round(size, 1)) + "GB |") | |
| costs = [(gpu, cost) for gpu, (vram, cost) in GPU_SPECS.items() if vram >= effective and cost > 0] | |
| if costs: | |
| costs.sort(key=lambda x: x[1]) | |
| out.append("") | |
| out.append("### Cloud Costs (8hr/day)") | |
| out.append("| GPU | $/hr | $/month |") | |
| out.append("|-----|------|---------|") | |
| for gpu, cost in costs[:4]: | |
| out.append("| " + gpu + " | $" + str(round(cost, 2)) + " | $" + str(int(cost * 176)) + " |") | |
| return "\n".join(out) | |
| except Exception as e: | |
| return "Error: " + str(e) | |
def compare(models_text, context):
    """Compare multiple models"""
    try:
        if not models_text:
            return "Enter model IDs, one per line"
        models = [m.strip() for m in models_text.strip().split("\n") if m.strip()]
        if len(models) < 2:
            return "Need at least 2 models"
        out = []
        out.append("## Comparison")
        out.append("| Model | Params | Inference | Training | QLoRA |")
        out.append("|-------|--------|-----------|----------|-------|")
        for mid in models[:5]:
            try:
                info = fetch_model_info(mid)
                config = fetch_config(mid)
                params, dtype = get_params(info)
                if params == 0:
                    out.append(f"| {mid} | Error | - | - | - |")
                    continue
                db = DTYPE_BYTES.get(dtype, 2)
                w = bytes_to_gb(params * db)
                layers = config.get("num_hidden_layers", 32)
                kv_heads = config.get("num_key_value_heads", 32)
                kv = bytes_to_gb(2 * layers * context * kv_heads * 128 * db)  # batch 1, head_dim 128
                inf = w + kv
                # Weights + gradients + Adam states + activations, matching the calculator's estimate.
                train = w + w + bytes_to_gb(params * 8) + w * 2
                qlora = bytes_to_gb(params * 0.5) * 1.5  # 4-bit base plus adapters and activations
                name = mid.split("/")[-1][:20]
                out.append(
                    f"| {name} | {round(params / 1e9, 1)}B | {round(inf, 1)}GB | "
                    f"{round(train, 1)}GB | {round(qlora, 1)}GB |"
                )
            except Exception:
                out.append(f"| {mid} | Error | - | - | - |")
        return "\n".join(out)
    except Exception as e:
        return "Error: " + str(e)


# Build the interface
with gr.Blocks(title="VRAM Calculator") as demo:
    gr.Markdown("# VRAM Calculator for LLMs")
    gr.Markdown("Estimate VRAM requirements for HuggingFace models")
    with gr.Tabs():
        with gr.TabItem("Calculator"):
            model_in = gr.Textbox(
                label="Model ID",
                placeholder="meta-llama/Llama-3.1-8B",
                info="Enter a HuggingFace model ID"
            )
            mode_in = gr.Radio(
                choices=["Inference", "Training (Full)", "LoRA", "QLoRA"],
                value="Inference",
                label="Mode"
            )
            with gr.Row():
                ctx_in = gr.Slider(
                    minimum=512,
                    maximum=131072,
                    value=4096,
                    step=512,
                    label="Context Length"
                )
                batch_in = gr.Slider(
                    minimum=1,
                    maximum=64,
                    value=1,
                    step=1,
                    label="Batch Size"
                )
            with gr.Accordion("Advanced Options", open=False):
                framework_in = gr.Dropdown(
                    choices=list(FRAMEWORKS.keys()),
                    value="vLLM",
                    label="Framework"
                )
                gpus_in = gr.Slider(
                    minimum=1,
                    maximum=8,
                    value=1,
                    step=1,
                    label="Number of GPUs"
                )
                lora_in = gr.Slider(
                    minimum=4,
                    maximum=128,
                    value=16,
                    step=4,
                    label="LoRA Rank"
                )
            calc_btn = gr.Button("Calculate", variant="primary")
            output = gr.Markdown()
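            # Inputs are passed to calculate() positionally, in the same order as its parameters.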
            calc_btn.click(
                fn=calculate,
                inputs=[model_in, ctx_in, batch_in, mode_in, framework_in, gpus_in, lora_in],
                outputs=output
            )
            gr.Examples(
                examples=[
                    ["meta-llama/Llama-3.1-8B"],
                    ["meta-llama/Llama-3.1-70B"],
                    ["mistralai/Mistral-7B-v0.1"],
                ],
                inputs=[model_in],
                label="Example Models"
            )
        with gr.TabItem("Compare Models"):
            cmp_in = gr.Textbox(
                label="Models (one per line)",
                lines=4,
                placeholder="meta-llama/Llama-3.1-8B\nmistralai/Mistral-7B-v0.1"
            )
            cmp_ctx = gr.Slider(
                minimum=512,
                maximum=131072,
                value=4096,
                step=512,
                label="Context Length"
            )
            cmp_btn = gr.Button("Compare", variant="primary")
            cmp_out = gr.Markdown()
            cmp_btn.click(
                fn=compare,
                inputs=[cmp_in, cmp_ctx],
                outputs=cmp_out
            )
    gr.Markdown("---")
    gr.Markdown("*Estimates are approximate. Actual usage may vary.*")

if __name__ == "__main__":
    demo.launch()