# Top-level build script for the sage_attention extension.
cmake_minimum_required(VERSION 3.26)
project(sage_attention LANGUAGES CXX)

# Backend the kernels are built for; later logic in this file only proceeds
# for "cuda" or "rocm".
set(TARGET_DEVICE "cuda" CACHE STRING "Target device backend for kernel")

# Emit `set(CMAKE_INSTALL_LOCAL_ONLY TRUE)` into the install script of every
# component (presumably to keep install rules of fetched sub-projects from
# running -- confirm against the packaging flow).
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)

# FetchContent defines FETCHCONTENT_BASE_DIR; create it up front and report it.
include(FetchContent)
file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR})
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")

# Architectures this project knows how to build for, per backend.
set(CUDA_SUPPORTED_ARCHS
  7.0 7.2 7.5
  8.0 8.6 8.7 8.9
  9.0
  10.0 10.1
  12.0
)
set(HIP_SUPPORTED_ARCHS
  gfx906 gfx908 gfx90a
  gfx940 gfx941 gfx942
  gfx1030
  gfx1100 gfx1101
)

# Project-local helper functions/macros used throughout the rest of this file.
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Locate the Python interpreter and development components (including the
# stable-ABI module component) needed to build the extension.
if(DEFINED Python_EXECUTABLE)
  # Allow passing through the interpreter (e.g. from setup.py).
  # REQUIRED is deliberately omitted here so that the failure message below
  # can name the interpreter that was requested.
  find_package(Python COMPONENTS Development Development.SABIModule Interpreter)
  if(NOT Python_FOUND)
    # Report the interpreter path that was actually passed in.
    message(FATAL_ERROR "Unable to find python matching: ${Python_EXECUTABLE}.")
  endif()
else()
  # No interpreter was pinned: let CMake pick one and fail hard if none fits.
  find_package(Python REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
endif()
# Make Torch discoverable, then require it.
# NOTE(review): append_cmake_prefix_path is a project helper not defined in
# this file; presumably it extends CMAKE_PREFIX_PATH with the value of
# torch.utils.cmake_prefix_path -- confirm in cmake/utils.cmake.
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
find_package(Torch REQUIRED)

# Only the "cuda" and "rocm" backends are configured below; for any other
# TARGET_DEVICE stop processing this file here.
if(NOT (TARGET_DEVICE STREQUAL "cuda" OR TARGET_DEVICE STREQUAL "rocm"))
  return()
endif()
# Default CUDA architecture list. Both variants share the entries up to 8.9;
# the tail depends on whether the CUDA compiler is known to be >= 12.8.
set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9")
if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
   CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
  # Newer toolchain: add 10.x and 12.0, with PTX on the last entry.
  list(APPEND CUDA_DEFAULT_KERNEL_ARCHS "9.0" "10.0" "10.1" "12.0+PTX")
else()
  # Older or unknown toolchain: stop at 9.0, with PTX.
  list(APPEND CUDA_DEFAULT_KERNEL_ARCHS "9.0+PTX")
endif()
# Choose the GPU language from what was detected: HIP wins whenever it was
# found, CUDA is used only when HIP is absent, and anything else is fatal.
if(HIP_FOUND)
  set(GPU_LANG "HIP")
  # Importing torch recognizes and sets up some HIP/ROCm configuration but
  # does not let cmake recognize .hip files. HIP has to be enabled explicitly
  # for cmake to understand the .hip extension automatically.
  enable_language(HIP)
elseif(CUDA_FOUND)
  set(GPU_LANG "CUDA")
else()
  message(FATAL_ERROR "Can't find CUDA or HIP installation.")
endif()
# Resolve the architectures to build for and add the backend-specific compile
# definition (CUDA_KERNEL or ROCM_KERNEL).
if(GPU_LANG STREQUAL "CUDA")
  # NOTE(review): clear_cuda_arches and extract_unique_cuda_archs_ascending
  # are project helpers not defined in this file; they are expected to
  # populate CUDA_ARCH_FLAGS and derive CUDA_ARCHS from it -- confirm in
  # cmake/utils.cmake.
  clear_cuda_arches(CUDA_ARCH_FLAGS)
  extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}")
  message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")

  # Filter the target architectures by the supported archs,
  # since for some files we will build for all CUDA_ARCHS.
  cuda_archs_loose_intersection(CUDA_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
  message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")

  # GPU_LANG is already known to be CUDA in this branch, so only the
  # user-provided thread count needs checking.
  if(NVCC_THREADS)
    list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}")
  endif()

  add_compile_definitions(CUDA_KERNEL)
elseif(GPU_LANG STREQUAL "HIP")
  set(ROCM_ARCHS "${HIP_SUPPORTED_ARCHS}")
  # TODO: remove this once we can set specific archs per source file set.
  override_gpu_arches(GPU_ARCHES
    ${GPU_LANG}
    "${${GPU_LANG}_SUPPORTED_ARCHS}")
  add_compile_definitions(ROCM_KERNEL)
else()
  # Fallback for any other GPU_LANG value: resolve the architectures from
  # <GPU_LANG>_SUPPORTED_ARCHS without adding a backend definition.
  override_gpu_arches(GPU_ARCHES
    ${GPU_LANG}
    "${${GPU_LANG}_SUPPORTED_ARCHS}")
endif()
# Add the GPU compiler flags reported for Torch to the project-wide GPU flags.
get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})

# Torch binding sources; always part of the extension regardless of backend.
set(TORCH_sage_attention_SRC
  torch-ext/torch_binding.cpp
  torch-ext/torch_binding.h
)
list(APPEND SRC ${TORCH_sage_attention_SRC})
# Kernel group `_qattn_sm80`: sources and headers.
set(_qattn_sm80_SRC
  "sage_attention/qattn/qk_int_sv_f16_cuda_sm80.cu"
  "sage_attention/qattn/attn_cuda_sm80.h"
  "sage_attention/qattn/attn_utils.cuh"
)

# Give every file in the group the source root as a per-file include directory.
# TODO: check whether CLion supports per-file include directories:
# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
set_source_files_properties(${_qattn_sm80_SRC}
  PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/.")

if(GPU_LANG STREQUAL "CUDA")
  # Restrict the group to the requested architectures that match 8.0.
  cuda_archs_loose_intersection(_qattn_sm80_ARCHS "8.0" "${CUDA_ARCHS}")
  message(STATUS "Capabilities for kernel _qattn_sm80: ${_qattn_sm80_ARCHS}")
  set_gencode_flags_for_srcs(SRCS "${_qattn_sm80_SRC}" CUDA_ARCHS "${_qattn_sm80_ARCHS}")

  # One pass per file: CUDA options go on .cu files only, then the CXX options
  # on every file (same per-file append order as two separate passes).
  foreach(kernel_file IN LISTS _qattn_sm80_SRC)
    if(kernel_file MATCHES ".*\\.cu$")
      set_property(SOURCE ${kernel_file} APPEND PROPERTY COMPILE_OPTIONS
        "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-std=c++17;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;--use_fast_math;--threads=1;-Xptxas=-v;-diag-suppress=174>")
    endif()
    set_property(SOURCE ${kernel_file} APPEND PROPERTY COMPILE_OPTIONS
      "$<$<COMPILE_LANGUAGE:CXX>:-g;-O3;-fopenmp;-lgomp;-std=c++17;-DENABLE_BF16>")
  endforeach()

  list(APPEND SRC ${_qattn_sm80_SRC})
endif()
# Kernel group `_qattn_sm90`: sources and headers.
set(_qattn_sm90_SRC
  "sage_attention/qattn/qk_int_sv_f8_cuda_sm90.cu"
  "sage_attention/qattn/attn_cuda_sm90.h"
  "sage_attention/qattn/attn_utils.cuh"
  "sage_attention/cuda_tensormap_shim.cuh"
)

# Give every file in the group the source root as a per-file include directory.
# TODO: check whether CLion supports per-file include directories:
# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
set_source_files_properties(${_qattn_sm90_SRC}
  PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/.")

if(GPU_LANG STREQUAL "CUDA")
  # Restrict the group to the requested architectures that match 9.0 / 9.0a.
  cuda_archs_loose_intersection(_qattn_sm90_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
  message(STATUS "Capabilities for kernel _qattn_sm90: ${_qattn_sm90_ARCHS}")
  set_gencode_flags_for_srcs(SRCS "${_qattn_sm90_SRC}" CUDA_ARCHS "${_qattn_sm90_ARCHS}")

  # One pass per file: CUDA options go on .cu files only, then the CXX options
  # on every file (same per-file append order as two separate passes).
  foreach(kernel_file IN LISTS _qattn_sm90_SRC)
    if(kernel_file MATCHES ".*\\.cu$")
      set_property(SOURCE ${kernel_file} APPEND PROPERTY COMPILE_OPTIONS
        "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-std=c++17;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;--use_fast_math;--threads=1;-Xptxas=-v;-diag-suppress=174>")
    endif()
    set_property(SOURCE ${kernel_file} APPEND PROPERTY COMPILE_OPTIONS
      "$<$<COMPILE_LANGUAGE:CXX>:-g;-O3;-fopenmp;-lgomp;-std=c++17;-DENABLE_BF16>")
  endforeach()

  list(APPEND SRC ${_qattn_sm90_SRC})
endif()
# Kernel group `_qattn_sm89`: sources and headers.
set(_qattn_sm89_SRC
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu"
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu"
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu"
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu"
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu"
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu"
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
  "sage_attention/qattn/attn_cuda_sm89.h"
  "sage_attention/qattn/qk_int_sv_f8_cuda_sm89.cuh"
  "sage_attention/qattn/attn_utils.cuh"
)

# Give every file in the group the source root as a per-file include directory.
# TODO: check whether CLion supports per-file include directories:
# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
set_source_files_properties(${_qattn_sm89_SRC}
  PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/.")

if(GPU_LANG STREQUAL "CUDA")
  # Restrict the group to the requested architectures that match 8.9.
  cuda_archs_loose_intersection(_qattn_sm89_ARCHS "8.9" "${CUDA_ARCHS}")
  message(STATUS "Capabilities for kernel _qattn_sm89: ${_qattn_sm89_ARCHS}")
  set_gencode_flags_for_srcs(SRCS "${_qattn_sm89_SRC}" CUDA_ARCHS "${_qattn_sm89_ARCHS}")

  # One pass per file: CUDA options go on .cu files only, then the CXX options
  # on every file (same per-file append order as two separate passes).
  foreach(kernel_file IN LISTS _qattn_sm89_SRC)
    if(kernel_file MATCHES ".*\\.cu$")
      set_property(SOURCE ${kernel_file} APPEND PROPERTY COMPILE_OPTIONS
        "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-std=c++17;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;--use_fast_math;--threads=1;-Xptxas=-v;-diag-suppress=174>")
    endif()
    set_property(SOURCE ${kernel_file} APPEND PROPERTY COMPILE_OPTIONS
      "$<$<COMPILE_LANGUAGE:CXX>:-g;-O3;-fopenmp;-lgomp;-std=c++17;-DENABLE_BF16>")
  endforeach()

  list(APPEND SRC ${_qattn_sm89_SRC})
endif()
# Kernel group `_fused`: source and header.
set(_fused_SRC
  "sage_attention/fused/fused.cu"
  "sage_attention/fused/fused.h"
)

# Give every file in the group the source root as a per-file include directory.
# TODO: check whether CLion supports per-file include directories:
# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
set_source_files_properties(${_fused_SRC}
  PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/.")

if(GPU_LANG STREQUAL "CUDA")
  # Restrict the group to the requested architectures matching 8.0/8.9/9.0/9.0a.
  cuda_archs_loose_intersection(_fused_ARCHS "8.0;8.9;9.0;9.0a" "${CUDA_ARCHS}")
  message(STATUS "Capabilities for kernel _fused: ${_fused_ARCHS}")
  set_gencode_flags_for_srcs(SRCS "${_fused_SRC}" CUDA_ARCHS "${_fused_ARCHS}")

  # One pass per file: CUDA options go on .cu files only, then the CXX options
  # on every file (same per-file append order as two separate passes).
  foreach(kernel_file IN LISTS _fused_SRC)
    if(kernel_file MATCHES ".*\\.cu$")
      set_property(SOURCE ${kernel_file} APPEND PROPERTY COMPILE_OPTIONS
        "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-std=c++17;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;--use_fast_math;--threads=1;-Xptxas=-v;-diag-suppress=174>")
    endif()
    set_property(SOURCE ${kernel_file} APPEND PROPERTY COMPILE_OPTIONS
      "$<$<COMPILE_LANGUAGE:CXX>:-g;-O3;-fopenmp;-lgomp;-std=c++17;-DENABLE_BF16>")
  endforeach()

  list(APPEND SRC ${_fused_SRC})
endif()
# Kernel group `_qattn`: shared headers (no .cu file is listed in this group).
set(_qattn_SRC
  "sage_attention/cp_async.cuh"
  "sage_attention/dispatch_utils.h"
  "sage_attention/math.cuh"
  "sage_attention/mma.cuh"
  "sage_attention/numeric_conversion.cuh"
  "sage_attention/permuted_smem.cuh"
  "sage_attention/reduction_utils.cuh"
  "sage_attention/wgmma.cuh"
  "sage_attention/utils.cuh"
  "sage_attention/cuda_tensormap_shim.cuh"
)

if(GPU_LANG STREQUAL "CUDA")
  # Restrict the group to the requested architectures matching 8.0/8.9/9.0/9.0a.
  cuda_archs_loose_intersection(_qattn_ARCHS "8.0;8.9;9.0;9.0a" "${CUDA_ARCHS}")
  message(STATUS "Capabilities for kernel _qattn: ${_qattn_ARCHS}")
  set_gencode_flags_for_srcs(SRCS "${_qattn_SRC}" CUDA_ARCHS "${_qattn_ARCHS}")

  # One pass per file: CUDA options go on .cu files only, then the CXX options
  # on every file (same per-file append order as two separate passes).
  foreach(kernel_file IN LISTS _qattn_SRC)
    if(kernel_file MATCHES ".*\\.cu$")
      set_property(SOURCE ${kernel_file} APPEND PROPERTY COMPILE_OPTIONS
        "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-std=c++17;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;--use_fast_math;--threads=1;-Xptxas=-v;-diag-suppress=174>")
    endif()
    set_property(SOURCE ${kernel_file} APPEND PROPERTY COMPILE_OPTIONS
      "$<$<COMPILE_LANGUAGE:CXX>:-g;-O3;-fopenmp;-lgomp;-std=c++17;-DENABLE_BF16>")
  endforeach()

  list(APPEND SRC ${_qattn_SRC})
endif()
# Name of the extension target; the same string is used as the install
# destination.
set(sage_attention_ext_name _sage_attention_57cb7ec_dirty)

# Build the extension from every source collected in SRC, using the GPU
# language, flags and architectures resolved earlier in this file.
# NOTE(review): define_gpu_extension_target is a project helper not defined in
# this file; the meaning of USE_SABI / WITH_SOABI is decided there.
define_gpu_extension_target(
  ${sage_attention_ext_name}
  DESTINATION ${sage_attention_ext_name}
  LANGUAGE ${GPU_LANG}
  SOURCES ${SRC}
  COMPILE_FLAGS ${GPU_FLAGS}
  ARCHITECTURES ${GPU_ARCHES}
  USE_SABI 3
  WITH_SOABI)

# Link the C++ standard library statically into the extension.
target_link_options(${sage_attention_ext_name} PRIVATE -static-libstdc++)