# sage_attention / CMakeLists.txt
cmake_minimum_required(VERSION 3.26)
project(sage_attention LANGUAGES CXX)
# Backend selection; anything other than "cuda"/"rocm" makes this file a no-op.
set(TARGET_DEVICE "cuda" CACHE STRING "Target device backend for kernel")
# Restrict `cmake --install` to this directory's rules (skip subproject installs).
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
include(FetchContent)
# FETCHCONTENT_BASE_DIR is given a default by the FetchContent module included
# above (unless overridden on the command line).
file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
# Architectures this kernel is known to build for; actual targets are the
# intersection of these with whatever torch/the toolchain was configured for.
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101")
# Project helpers used below (clear_cuda_arches, cuda_archs_loose_intersection,
# set_gencode_flags_for_srcs, define_gpu_extension_target, ...).
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Locate the Python interpreter and development files (including the stable
# ABI module needed for SABI builds). When Python_EXECUTABLE is pre-set (e.g.
# passed through from setup.py), honor it and fail loudly if it can't be
# matched; otherwise let CMake pick one and make the find REQUIRED.
if(DEFINED Python_EXECUTABLE)
  # Allow passing through the interpreter (e.g. from setup.py).
  find_package(Python COMPONENTS Development Development.SABIModule Interpreter)
  if(NOT Python_FOUND)
    # BUGFIX: the original expanded ${EXECUTABLE}, which is never defined, so
    # the error printed an empty path. Report the requested interpreter.
    message(FATAL_ERROR "Unable to find python matching: ${Python_EXECUTABLE}.")
  endif()
else()
  find_package(Python REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
endif()
# Make torch's bundled CMake package config discoverable, then import it
# (this also sets up CUDA_FOUND / HIP_FOUND used below).
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
find_package(Torch REQUIRED)

# Only CUDA and ROCm backends are handled here; bail out early otherwise.
# Quote the expansion so an empty or variable-named cache value cannot be
# re-dereferenced by if() (CMP0054 footgun).
if(NOT "${TARGET_DEVICE}" STREQUAL "cuda" AND
   NOT "${TARGET_DEVICE}" STREQUAL "rocm")
  return()
endif()
# Default kernel architecture list: CUDA 12.8+ toolchains additionally get
# the SM 10.x / 12.0 entries.
if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
   CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
  set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0+PTX")
else()
  set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0+PTX")
endif()

# Select the GPU language: HIP takes precedence when found, then CUDA;
# fail hard when neither toolchain is available.
if(HIP_FOUND)
  set(GPU_LANG "HIP")
  # Importing torch recognizes and sets up some HIP/ROCm configuration but does
  # not let cmake recognize .hip files. In order to get cmake to understand the
  # .hip extension automatically, HIP must be enabled explicitly.
  enable_language(HIP)
elseif(CUDA_FOUND)
  set(GPU_LANG "CUDA")
else()
  message(FATAL_ERROR "Can't find CUDA or HIP installation.")
endif()
# Per-language architecture configuration.
if(GPU_LANG STREQUAL "CUDA")
  # Recover the arch flags torch was built with and reduce them to a unique,
  # ascending list of compute capabilities.
  clear_cuda_arches(CUDA_ARCH_FLAGS)
  extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}")
  message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")
  # Filter the target architectures by the supported archs
  # since for some files we will build for all CUDA_ARCHS.
  cuda_archs_loose_intersection(CUDA_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
  message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")
  # Optional parallel device compilation per translation unit.
  # (Dropped the original's redundant re-test of GPU_LANG here; this code is
  # already inside the CUDA branch.)
  if(NVCC_THREADS)
    list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}")
  endif()
  add_compile_definitions(CUDA_KERNEL)
elseif(GPU_LANG STREQUAL "HIP")
  set(ROCM_ARCHS "${HIP_SUPPORTED_ARCHS}")
  # TODO: remove this once we can set specific archs per source file set.
  override_gpu_arches(GPU_ARCHES
    ${GPU_LANG}
    "${${GPU_LANG}_SUPPORTED_ARCHS}")
  add_compile_definitions(ROCM_KERNEL)
endif()
# NOTE: the original had a third (else) branch calling override_gpu_arches,
# but GPU_LANG is always "CUDA" or "HIP" at this point (anything else is a
# FATAL_ERROR above), so that branch was unreachable and has been removed.
# Inherit torch's GPU compiler flags for the selected language.
get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})

# Torch extension binding sources — always part of the build.
set(TORCH_sage_attention_SRC
  torch-ext/torch_binding.cpp
  torch-ext/torch_binding.h
)
list(APPEND SRC ${TORCH_sage_attention_SRC})
# SM80 (Ampere) attention kernel sources.
set(_qattn_sm80_SRC
  "sage_attention/qattn/qk_int_sv_f16_cuda_sm80.cu"
  "sage_attention/qattn/attn_cuda_sm80.h"
  "sage_attention/qattn/attn_utils.cuh"
)
# Per-file include dirs so kernel sources resolve project-relative includes.
# Use PROJECT_SOURCE_DIR rather than CMAKE_SOURCE_DIR so this keeps working
# when sage_attention is built as a subproject of a larger build.
# TODO: check if CLion support this:
# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
set_source_files_properties(
  ${_qattn_sm80_SRC}
  PROPERTIES INCLUDE_DIRECTORIES "${PROJECT_SOURCE_DIR}")
if(GPU_LANG STREQUAL "CUDA")
  # Build this kernel only for SM 8.0 (when present in CUDA_ARCHS).
  cuda_archs_loose_intersection(_qattn_sm80_ARCHS "8.0" "${CUDA_ARCHS}")
  message(STATUS "Capabilities for kernel _qattn_sm80: ${_qattn_sm80_ARCHS}")
  set_gencode_flags_for_srcs(SRCS "${_qattn_sm80_SRC}" CUDA_ARCHS "${_qattn_sm80_ARCHS}")
  # Device (nvcc) options for the .cu sources only.
  foreach(_KERNEL_SRC ${_qattn_sm80_SRC})
    if(_KERNEL_SRC MATCHES ".*\\.cu$")
      set_property(
        SOURCE ${_KERNEL_SRC}
        APPEND PROPERTY
        COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-std=c++17;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;--use_fast_math;--threads=1;-Xptxas=-v;-diag-suppress=174>"
      )
    endif()
  endforeach()
  # Host (C++) options. BUGFIX: removed -lgomp — it is a linker input, not a
  # compile option, and has no effect in COMPILE_OPTIONS; OpenMP runtime
  # linking must be handled at link time on the target.
  foreach(_KERNEL_SRC ${_qattn_sm80_SRC})
    set_property(
      SOURCE ${_KERNEL_SRC}
      APPEND PROPERTY
      COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CXX>:-g;-O3;-fopenmp;-std=c++17;-DENABLE_BF16>"
    )
  endforeach()
  list(APPEND SRC "${_qattn_sm80_SRC}")
endif()
# SM90 (Hopper) attention kernel sources (FP8 SV path).
set(_qattn_sm90_SRC
  "sage_attention/qattn/qk_int_sv_f8_cuda_sm90.cu"
  "sage_attention/qattn/attn_cuda_sm90.h"
  "sage_attention/qattn/attn_utils.cuh"
  "sage_attention/cuda_tensormap_shim.cuh"
)
# Per-file include dirs so kernel sources resolve project-relative includes.
# Use PROJECT_SOURCE_DIR rather than CMAKE_SOURCE_DIR so this keeps working
# when sage_attention is built as a subproject of a larger build.
# TODO: check if CLion support this:
# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
set_source_files_properties(
  ${_qattn_sm90_SRC}
  PROPERTIES INCLUDE_DIRECTORIES "${PROJECT_SOURCE_DIR}")
if(GPU_LANG STREQUAL "CUDA")
  # Build this kernel only for SM 9.0 / 9.0a (when present in CUDA_ARCHS).
  cuda_archs_loose_intersection(_qattn_sm90_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
  message(STATUS "Capabilities for kernel _qattn_sm90: ${_qattn_sm90_ARCHS}")
  set_gencode_flags_for_srcs(SRCS "${_qattn_sm90_SRC}" CUDA_ARCHS "${_qattn_sm90_ARCHS}")
  # Device (nvcc) options for the .cu sources only.
  foreach(_KERNEL_SRC ${_qattn_sm90_SRC})
    if(_KERNEL_SRC MATCHES ".*\\.cu$")
      set_property(
        SOURCE ${_KERNEL_SRC}
        APPEND PROPERTY
        COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-std=c++17;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;--use_fast_math;--threads=1;-Xptxas=-v;-diag-suppress=174>"
      )
    endif()
  endforeach()
  # Host (C++) options. BUGFIX: removed -lgomp — it is a linker input, not a
  # compile option, and has no effect in COMPILE_OPTIONS; OpenMP runtime
  # linking must be handled at link time on the target.
  foreach(_KERNEL_SRC ${_qattn_sm90_SRC})
    set_property(
      SOURCE ${_KERNEL_SRC}
      APPEND PROPERTY
      COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CXX>:-g;-O3;-fopenmp;-std=c++17;-DENABLE_BF16>"
    )
  endforeach()
  list(APPEND SRC "${_qattn_sm90_SRC}")
endif()
# SM89 (Ada) attention kernel sources (INT8 QK / FP8 SV variants).
set(_qattn_sm89_SRC
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu"
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu"
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu"
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu"
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu"
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu"
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
  "sage_attention/qattn/attn_cuda_sm89.h"
  "sage_attention/qattn/qk_int_sv_f8_cuda_sm89.cuh"
  "sage_attention/qattn/attn_utils.cuh"
)
# Per-file include dirs so kernel sources resolve project-relative includes.
# Use PROJECT_SOURCE_DIR rather than CMAKE_SOURCE_DIR so this keeps working
# when sage_attention is built as a subproject of a larger build.
# TODO: check if CLion support this:
# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
set_source_files_properties(
  ${_qattn_sm89_SRC}
  PROPERTIES INCLUDE_DIRECTORIES "${PROJECT_SOURCE_DIR}")
if(GPU_LANG STREQUAL "CUDA")
  # Build these kernels only for SM 8.9 (when present in CUDA_ARCHS).
  cuda_archs_loose_intersection(_qattn_sm89_ARCHS "8.9" "${CUDA_ARCHS}")
  message(STATUS "Capabilities for kernel _qattn_sm89: ${_qattn_sm89_ARCHS}")
  set_gencode_flags_for_srcs(SRCS "${_qattn_sm89_SRC}" CUDA_ARCHS "${_qattn_sm89_ARCHS}")
  # Device (nvcc) options for the .cu sources only.
  foreach(_KERNEL_SRC ${_qattn_sm89_SRC})
    if(_KERNEL_SRC MATCHES ".*\\.cu$")
      set_property(
        SOURCE ${_KERNEL_SRC}
        APPEND PROPERTY
        COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-std=c++17;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;--use_fast_math;--threads=1;-Xptxas=-v;-diag-suppress=174>"
      )
    endif()
  endforeach()
  # Host (C++) options. BUGFIX: removed -lgomp — it is a linker input, not a
  # compile option, and has no effect in COMPILE_OPTIONS; OpenMP runtime
  # linking must be handled at link time on the target.
  foreach(_KERNEL_SRC ${_qattn_sm89_SRC})
    set_property(
      SOURCE ${_KERNEL_SRC}
      APPEND PROPERTY
      COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CXX>:-g;-O3;-fopenmp;-std=c++17;-DENABLE_BF16>"
    )
  endforeach()
  list(APPEND SRC "${_qattn_sm89_SRC}")
endif()
# Fused quantization/attention helper kernel sources.
set(_fused_SRC
  "sage_attention/fused/fused.cu"
  "sage_attention/fused/fused.h"
)
# Per-file include dirs so kernel sources resolve project-relative includes.
# Use PROJECT_SOURCE_DIR rather than CMAKE_SOURCE_DIR so this keeps working
# when sage_attention is built as a subproject of a larger build.
# TODO: check if CLion support this:
# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
set_source_files_properties(
  ${_fused_SRC}
  PROPERTIES INCLUDE_DIRECTORIES "${PROJECT_SOURCE_DIR}")
if(GPU_LANG STREQUAL "CUDA")
  # Built for every supported arch family (when present in CUDA_ARCHS).
  cuda_archs_loose_intersection(_fused_ARCHS "8.0;8.9;9.0;9.0a" "${CUDA_ARCHS}")
  message(STATUS "Capabilities for kernel _fused: ${_fused_ARCHS}")
  set_gencode_flags_for_srcs(SRCS "${_fused_SRC}" CUDA_ARCHS "${_fused_ARCHS}")
  # Device (nvcc) options for the .cu sources only.
  foreach(_KERNEL_SRC ${_fused_SRC})
    if(_KERNEL_SRC MATCHES ".*\\.cu$")
      set_property(
        SOURCE ${_KERNEL_SRC}
        APPEND PROPERTY
        COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-std=c++17;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;--use_fast_math;--threads=1;-Xptxas=-v;-diag-suppress=174>"
      )
    endif()
  endforeach()
  # Host (C++) options. BUGFIX: removed -lgomp — it is a linker input, not a
  # compile option, and has no effect in COMPILE_OPTIONS; OpenMP runtime
  # linking must be handled at link time on the target.
  foreach(_KERNEL_SRC ${_fused_SRC})
    set_property(
      SOURCE ${_KERNEL_SRC}
      APPEND PROPERTY
      COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CXX>:-g;-O3;-fopenmp;-std=c++17;-DENABLE_BF16>"
    )
  endforeach()
  list(APPEND SRC "${_fused_SRC}")
endif()
# Shared attention headers (no .cu translation units in this set; the
# per-.cu device-option loop below therefore matches nothing, and the
# source-file properties only matter if these are ever compiled directly).
set(_qattn_SRC
  "sage_attention/cp_async.cuh"
  "sage_attention/dispatch_utils.h"
  "sage_attention/math.cuh"
  "sage_attention/mma.cuh"
  "sage_attention/numeric_conversion.cuh"
  "sage_attention/permuted_smem.cuh"
  "sage_attention/reduction_utils.cuh"
  "sage_attention/wgmma.cuh"
  "sage_attention/utils.cuh"
  "sage_attention/cuda_tensormap_shim.cuh"
)
if(GPU_LANG STREQUAL "CUDA")
  cuda_archs_loose_intersection(_qattn_ARCHS "8.0;8.9;9.0;9.0a" "${CUDA_ARCHS}")
  message(STATUS "Capabilities for kernel _qattn: ${_qattn_ARCHS}")
  set_gencode_flags_for_srcs(SRCS "${_qattn_SRC}" CUDA_ARCHS "${_qattn_ARCHS}")
  # Device (nvcc) options for .cu sources only (none in this set today).
  foreach(_KERNEL_SRC ${_qattn_SRC})
    if(_KERNEL_SRC MATCHES ".*\\.cu$")
      set_property(
        SOURCE ${_KERNEL_SRC}
        APPEND PROPERTY
        COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-std=c++17;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;--use_fast_math;--threads=1;-Xptxas=-v;-diag-suppress=174>"
      )
    endif()
  endforeach()
  # Host (C++) options. BUGFIX: removed -lgomp — it is a linker input, not a
  # compile option, and has no effect in COMPILE_OPTIONS; OpenMP runtime
  # linking must be handled at link time on the target.
  foreach(_KERNEL_SRC ${_qattn_SRC})
    set_property(
      SOURCE ${_KERNEL_SRC}
      APPEND PROPERTY
      COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CXX>:-g;-O3;-fopenmp;-std=c++17;-DENABLE_BF16>"
    )
  endforeach()
  list(APPEND SRC "${_qattn_SRC}")
endif()
# Create the Python extension module from everything accumulated in SRC.
# define_gpu_extension_target is a project helper from cmake/utils.cmake;
# USE_SABI 3 / WITH_SOABI build against the Python stable ABI and add the
# platform SOABI suffix to the module filename.
define_gpu_extension_target(
_sage_attention_57cb7ec_dirty
DESTINATION _sage_attention_57cb7ec_dirty
LANGUAGE ${GPU_LANG}
SOURCES ${SRC}
COMPILE_FLAGS ${GPU_FLAGS}
ARCHITECTURES ${GPU_ARCHES}
#INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)
# Statically link libstdc++ into the module — presumably to avoid libstdc++
# version mismatches with the host Python/torch environment; confirm this is
# intended for all toolchains (it is a GNU-specific driver option).
target_link_options(_sage_attention_57cb7ec_dirty PRIVATE -static-libstdc++)