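"""Shape and dtype smoke tests for the sage_attention quantization helpers.

Covers per_block_int8, per_warp_int8, sub_mean, and per_channel_fp8 for both
the "HND" and "NHD" tensor layouts. These tests only check output shapes,
dtypes, devices, and basic numerical sanity (finite scales, near-zero means
after smoothing); they do not validate the quantized values themselves.
"""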
import math

import pytest
import torch

import sage_attention as sa

cuda_available = torch.cuda.is_available()


@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
def test_per_block_int8_shapes_and_types(tensor_layout):
    device = "cuda"
    dtype = torch.float16

    if tensor_layout == "HND":
        q = torch.randn(2, 4, 129, 128, dtype=dtype, device=device)
        k = torch.randn(2, 4, 257, 128, dtype=dtype, device=device)
    else:
        q = torch.randn(2, 129, 4, 128, dtype=dtype, device=device)
        k = torch.randn(2, 257, 4, 128, dtype=dtype, device=device)

    # Expected scale shapes assume one q scale per block of 128 sequence
    # positions and one k scale per block of 64; they are layout-independent.
    expected_q_scale_shape = (2, 4, math.ceil(129 / 128))
    expected_k_scale_shape = (2, 4, math.ceil(257 / 64))

    # km has the same shape under both layouts, so no branch is needed.
    km = torch.randn(2, 4, 128, dtype=dtype, device=device)

    q_int8, q_scale, k_int8, k_scale = sa.per_block_int8(
        q, k, km, tensor_layout=tensor_layout
    )

    assert q_int8.shape == q.shape and q_int8.dtype == torch.int8
    assert k_int8.shape == k.shape and k_int8.dtype == torch.int8
    assert q_scale.shape == expected_q_scale_shape and q_scale.dtype == torch.float32
    assert k_scale.shape == expected_k_scale_shape and k_scale.dtype == torch.float32
    assert q_int8.device == q.device == k.device == q_scale.device == k_scale.device
    assert torch.isfinite(q_scale).all()
    assert torch.isfinite(k_scale).all()


@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
@pytest.mark.parametrize("head_dim", [64, 128])
def test_per_warp_int8_shapes_and_types(tensor_layout, head_dim):
    device = "cuda"
    dtype = torch.float16
    # WARPQ depends on head_dim and mirrors the kernel arguments passed below.
    warpq = 16 if head_dim == 128 else 32

    if tensor_layout == "HND":
        q = torch.randn(1, 2, 130, head_dim, dtype=dtype, device=device)
        k = torch.randn(1, 2, 70, head_dim, dtype=dtype, device=device)
    else:
        q = torch.randn(1, 130, 2, head_dim, dtype=dtype, device=device)
        k = torch.randn(1, 70, 2, head_dim, dtype=dtype, device=device)

    # q gets one scale per warp tile (BLKQ // WARPQ tiles per block of 128);
    # k gets one scale per block of 64. Shapes are layout-independent.
    expected_q_scale_shape = (1, 2, math.ceil(130 / 128) * (128 // warpq))
    expected_k_scale_shape = (1, 2, math.ceil(70 / 64))

    q_int8, q_scale, k_int8, k_scale = sa.per_warp_int8(
        q,
        k,
        tensor_layout=tensor_layout,
        BLKQ=128,
        WARPQ=warpq,
        BLKK=64,
    )

    assert q_int8.shape == q.shape and q_int8.dtype == torch.int8
    assert k_int8.shape == k.shape and k_int8.dtype == torch.int8
    assert q_scale.shape == expected_q_scale_shape and q_scale.dtype == torch.float32
    assert k_scale.shape == expected_k_scale_shape and k_scale.dtype == torch.float32
    assert torch.isfinite(q_scale).all()
    assert torch.isfinite(k_scale).all()


@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
def test_sub_mean_properties(tensor_layout):
    device = "cuda"
    dtype = torch.float16

    if tensor_layout == "HND":
        v = torch.randn(2, 3, 65, 128, dtype=dtype, device=device)
        seq_dim = 2
        nh_dim = 1
    else:
        v = torch.randn(2, 65, 3, 128, dtype=dtype, device=device)
        seq_dim = 1
        nh_dim = 2

    v_smoothed, vm = sa.sub_mean(v, tensor_layout=tensor_layout)

    assert v_smoothed.shape == v.shape and v_smoothed.dtype == torch.float16
    assert vm.shape == (v.size(0), v.size(nh_dim), v.size(-1)) and vm.dtype == v.dtype

    # After subtracting the per-head mean, the mean along the sequence
    # dimension should be approximately zero (loose tolerance for fp16).
    mean_after = v_smoothed.mean(dim=seq_dim)
    assert torch.isfinite(mean_after).all()
    assert (mean_after.abs() < 1e-1).all()


@pytest.mark.skipif(not cuda_available, reason="CUDA is required")
@pytest.mark.parametrize("tensor_layout", ["HND", "NHD"])
@pytest.mark.parametrize("smooth_v", [True, False])
def test_per_channel_fp8_shapes_and_outputs(tensor_layout, smooth_v):
    device = "cuda"
    dtype = torch.float16

    if tensor_layout == "HND":
        v = torch.randn(2, 3, 77, 128, dtype=dtype, device=device)
        kv_len = v.size(2)
    else:
        v = torch.randn(2, 77, 3, 128, dtype=dtype, device=device)
        kv_len = v.size(1)

    v_fp8, v_scale, vm = sa.per_channel_fp8(
        v, tensor_layout=tensor_layout, smooth_v=smooth_v
    )

    assert v_fp8.dtype == torch.float8_e4m3fn
    assert v_scale.shape == (2, 3, 128)
    # vm is only returned when smoothing is enabled.
    if smooth_v:
        assert vm is not None and vm.shape == (2, 3, 128) and vm.dtype == torch.float32
    else:
        assert vm is None

    # The kv sequence length in the fp8 output is padded up to a multiple of 64.
    padded_len = ((kv_len + 63) // 64) * 64
    if tensor_layout == "HND":
        assert v_fp8.shape == (2, 3, 128, padded_len)
    else:
        assert v_fp8.shape == (2, 128, 3, padded_len)
    assert torch.isfinite(v_scale).all()