""" Copyright (c) 2024 by SageAttention team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch import triton import triton.language as tl @triton.jit def quant_query_per_thread_int8_kernel(Input, Output, Scale, L, stride_iz, stride_ih, stride_in, stride_oz, stride_oh, stride_on, stride_sz, stride_sh, C: tl.constexpr, BLK: tl.constexpr): off_blk = tl.program_id(0) // 8 off_tld = tl.program_id(0) % 8 off_h = tl.program_id(1) off_b = tl.program_id(2) offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld offs_k = tl.arange(0, C) input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld x = tl.load(input_ptrs, mask=offs_n[:, None] < L) x = x.to(tl.float32) scale = tl.max(tl.abs(x)) / 127. + 0.0000001 x_int8 = x / scale x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) x_int8 = x_int8.to(tl.int8) tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) tl.store(scale_ptrs, scale) @triton.jit def quant_key_per_thread_int8_kernel(Input, Output, Scale, L, stride_iz, stride_ih, stride_in, stride_oz, stride_oh, stride_on, stride_sz, stride_sh, C: tl.constexpr, BLK: tl.constexpr): off_blk = tl.program_id(0) // 4 off_tld = tl.program_id(0) % 4 off_h = tl.program_id(1) off_b = tl.program_id(2) # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 # offs_k = tl.arange(0, C) # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld # x = tl.load(input_ptrs, mask=offs_n[:, None] < L) # x = x.to(tl.float32) # scale = tl.max(tl.abs(x)) / 127. + 0.0000001 # x_int8 = x / scale # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) # x_int8 = x_int8.to(tl.int8) # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) # tl.store(scale_ptrs, scale) offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 offs_k = tl.arange(0, C) input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) x0 = x0.to(tl.float32) x1 = x1.to(tl.float32) scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. + 0.0000001 x0_int8 = x0 / scale x1_int8 = x1 / scale x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) x0_int8 = x0_int8.to(tl.int8) x1_int8 = x1_int8.to(tl.int8) tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) tl.store(scale_ptrs, scale) @triton.jit def quant_query_per_thread_int4_kernel(Input, Output, Scale, L, stride_iz, stride_ih, stride_in, stride_oz, stride_oh, stride_on, stride_sz, stride_sh, C: tl.constexpr, BLK: tl.constexpr): off_blk = tl.program_id(0) // 8 off_tld = tl.program_id(0) % 8 off_h = tl.program_id(1) off_b = tl.program_id(2) offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld offs_k = tl.arange(0, C) input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld x = tl.load(input_ptrs, mask=offs_n[:, None] < L) x = x.to(tl.float32) scale = tl.max(tl.abs(x)) / 7. + 0.0000001 x_int8 = x / scale x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) x_int8 = x_int8.to(tl.int8) tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) tl.store(scale_ptrs, scale) @triton.jit def quant_key_per_thread_int4_kernel(Input, Output, Scale, L, stride_iz, stride_ih, stride_in, stride_oz, stride_oh, stride_on, stride_sz, stride_sh, C: tl.constexpr, BLK: tl.constexpr): off_blk = tl.program_id(0) // 4 off_tld = tl.program_id(0) % 4 off_h = tl.program_id(1) off_b = tl.program_id(2) offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 offs_k = tl.arange(0, C) input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld x = tl.load(input_ptrs, mask=offs_n[:, None] < L) x = x.to(tl.float32) scale = tl.max(tl.abs(x)) / 7. + 0.0000001 x_int8 = x / scale x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) x_int8 = x_int8.to(tl.int8) tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) tl.store(scale_ptrs, scale) def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"): q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) if km is not None: k = k - km if tensor_layout == "HND": b, h_qo, qo_len, head_dim = q.shape _, h_kv, kv_len, _ = k.shape stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) elif tensor_layout == "NHD": b, qo_len, h_qo, head_dim = q.shape _, kv_len, h_kv, _ = k.shape stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) else: raise ValueError(f"Unknown tensor layout: {tensor_layout}") q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32) k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32) if sm_scale is None: sm_scale = head_dim**-0.5 grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) quant_query_per_thread_int8_kernel[grid]( q, q_int8, q_scale, qo_len, stride_bz_q, stride_h_q, stride_seq_q, stride_bz_qo, stride_h_qo, stride_seq_qo, q_scale.stride(0), q_scale.stride(1), C=head_dim, BLK=WARPQ ) grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) quant_key_per_thread_int8_kernel[grid]( k, k_int8, k_scale, kv_len, stride_bz_k, stride_h_k, stride_seq_k, stride_bz_ko, stride_h_ko, stride_seq_ko, k_scale.stride(0), k_scale.stride(1), C=head_dim, BLK=WARPK ) return q_int8, q_scale, k_int8, k_scale