Merge pull request #506 from makllama/musa
feat: Support Moore Threads GPU
Azure-Tang authored Feb 20, 2025
2 parents 1dd84b4 + 2207f6c commit 25c5bdd
Showing 8 changed files with 145 additions and 34 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -28,3 +28,4 @@ ktransformers/tests/chat_txt.txt
mmlu_result_q4km.json
mmlu_result_q4km.log
ktransformers/tests/mmlu_result_silicon.log
ktransformers/ktransformers_ext/cuda_musa/
40 changes: 35 additions & 5 deletions ktransformers/ktransformers_ext/CMakeLists.txt
@@ -30,6 +30,8 @@ if (NOT MSVC)
option(LLAMA_F16C "llama: enable F16C" OFF)
endif()
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)

# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
@@ -208,8 +210,31 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
if (WIN32)
include_directories("$ENV{CUDA_PATH}/include")
elseif (UNIX)
find_package(CUDA REQUIRED)
include_directories("${CUDA_INCLUDE_DIRS}")
if (KTRANSFORMERS_USE_CUDA)
find_package(CUDA REQUIRED)
include_directories("${CUDA_INCLUDE_DIRS}")
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
endif()

if (KTRANSFORMERS_USE_MUSA)
if (NOT EXISTS $ENV{MUSA_PATH})
if (NOT EXISTS /opt/musa)
set(MUSA_PATH /usr/local/musa)
else()
set(MUSA_PATH /opt/musa)
endif()
else()
set(MUSA_PATH $ENV{MUSA_PATH})
endif()

list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")

find_package(MUSAToolkit)
if (MUSAToolkit_FOUND)
message(STATUS "MUSA Toolkit found")
add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
endif()
endif()
endif()

aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
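
The MUSA branch above resolves the toolkit location with a three-step fallback before handing `${MUSA_PATH}/cmake` to `find_package(MUSAToolkit)`. A minimal Python sketch of the same search order (the paths are the ones hard-coded in the CMake block; everything else here is illustrative):

```python
import os

def resolve_musa_path() -> str:
    # Same precedence as the CMake block: $MUSA_PATH if it points at an
    # existing directory, then /opt/musa, then /usr/local/musa as the default.
    env_path = os.environ.get("MUSA_PATH", "")
    if env_path and os.path.exists(env_path):
        return env_path
    if os.path.exists("/opt/musa"):
        return "/opt/musa"
    return "/usr/local/musa"

print(resolve_musa_path())
```
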
@@ -225,10 +250,15 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama)
if(WIN32)
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
elseif(UNIX)
if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "")
set(ENV{CUDA_HOME} "/usr/local/cuda")
if(KTRANSFORMERS_USE_CUDA)
if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "")
set(ENV{CUDA_HOME} "/usr/local/cuda")
endif()
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
endif()
if(KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
endif()
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
endif()

# Define the USE_NUMA option
6 changes: 5 additions & 1 deletion ktransformers/ktransformers_ext/cpu_backend/cpuinfer.h
@@ -17,7 +17,11 @@
#include <queue>
#include <thread>
#include <vector>
#include "cuda_runtime.h"
#ifdef KTRANSFORMERS_USE_CUDA
#include "vendors/cuda.h"
#elif KTRANSFORMERS_USE_MUSA
#include "vendors/musa.h"
#endif

#include "backend.h"
#include "task_queue.h"
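
The new vendor headers give `cpuinfer.h` a single per-backend include; which one is compiled in is fixed by the `KTRANSFORMERS_USE_CUDA`/`KTRANSFORMERS_USE_MUSA` definitions added in CMakeLists.txt. As a loose Python analogue only (the real selection happens at compile time via `#ifdef`; this mirrors the optional `torch_musa` import that setup.py adds):

```python
# Loose analogue: approximate compile-time backend selection with an import
# fallback, the same pattern setup.py uses to detect MUSA_HOME.
try:
    import torch_musa  # present only where the MUSA stack is installed
    BACKEND = "musa"
except ImportError:
    BACKEND = "cuda"

print(f"backend selected: {BACKEND}")
```
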
3 changes: 3 additions & 0 deletions ktransformers/ktransformers_ext/cpu_backend/vendors/README.md
@@ -0,0 +1,3 @@
## TODO

This directory can be removed after updating the version of `llama.cpp`.
3 changes: 3 additions & 0 deletions ktransformers/ktransformers_ext/cpu_backend/vendors/cuda.h
@@ -0,0 +1,3 @@
#pragma once

#include <cuda_runtime.h>
7 changes: 7 additions & 0 deletions ktransformers/ktransformers_ext/cpu_backend/vendors/musa.h
@@ -0,0 +1,7 @@
#pragma once

#include <musa_runtime.h>

#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaStream_t musaStream_t
#define cudaHostFn_t musaHostFn_t
8 changes: 6 additions & 2 deletions ktransformers/ktransformers_ext/cuda/binding.cpp
@@ -1,15 +1,17 @@
/**
* @Description :
* @Author : Azure-Tang
* @Date : 2024-07-25 13:38:30
* @Version : 1.0.0
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-12 03:05:04
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/

#include "custom_gguf/ops.h"
#ifdef KTRANSFORMERS_USE_CUDA
#include "gptq_marlin/ops.h"
#endif
// Python bindings
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
@@ -33,8 +35,10 @@ PYBIND11_MODULE(KTransformersOps, m) {
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
#ifdef KTRANSFORMERS_USE_CUDA
m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"));
#endif
}
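
Since the Marlin bindings are registered only under `KTRANSFORMERS_USE_CUDA`, a MUSA build of `KTransformersOps` exposes the dequantize functions but not `gptq_marlin_gemm`. A hedged sketch of how calling code could feature-detect this at runtime (module and function names come from the bindings above; the fallback message is illustrative):

```python
import importlib

ops = importlib.import_module("KTransformersOps")  # extension built by setup.py

if hasattr(ops, "gptq_marlin_gemm"):
    print("CUDA build: Marlin GEMM available")
else:
    print("MUSA build: Marlin GEMM not compiled in; use a non-Marlin code path")
```
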
111 changes: 85 additions & 26 deletions setup.py
@@ -1,16 +1,16 @@
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : chenxl
Date : 2024-07-27 16:15:27
Version : 1.0.0
LastEditors : chenxl
LastEditTime : 2024-08-14 16:36:19
Adapted from:
https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
Copyright (c) 2023, Tri Dao.
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''

import os
@@ -30,6 +30,11 @@
from setuptools import setup, Extension
from cpufeature.extension import CPUFeature
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
try:
from torch_musa.utils.simple_porting import SimplePorting
from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME
except ImportError:
MUSA_HOME=None

class CpuInstructInfo:
CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
@@ -40,7 +45,7 @@ class CpuInstructInfo:
CMAKE_FANCY = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON -DLLAMA_AVX512_FANCY_SIMD=ON"
CMAKE_AVX512 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON"
CMAKE_AVX2 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON"

class VersionInfo:
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PACKAGE_NAME = "ktransformers"
@@ -49,6 +54,16 @@ class VersionInfo:
)
FORCE_BUILD = os.getenv("KTRANSFORMERS_FORCE_BUILD", "FALSE") == "TRUE"

def get_musa_bare_metal_version(self, musa_dir):
raw_output = subprocess.run(
[musa_dir + "/bin/mcc", "-v"], check=True,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.decode("utf-8")
output = raw_output.split()
release_idx = output.index("version") + 1
bare_metal_version = parse(output[release_idx].split(",")[0])
musa_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
return musa_version
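
The new `get_musa_bare_metal_version` condenses the `mcc -v` banner into the two-digit tag used in the wheel version, symmetric with the existing nvcc probe below. A sketch of the parsing on a hypothetical banner (the real `mcc` output may differ; only the string handling mirrors the method above):

```python
from packaging.version import parse

raw_output = "mcc version 3.1.0, based on LLVM"  # hypothetical banner text
output = raw_output.split()
release_idx = output.index("version") + 1
bare_metal_version = parse(output[release_idx].split(",")[0])
print(f"{bare_metal_version.major}{bare_metal_version.minor}")  # -> 31
```
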

def get_cuda_bare_metal_version(self, cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@@ -58,7 +73,7 @@ def get_cuda_bare_metal_version(self, cuda_dir):
cuda_version = f"{bare_metal_version.major}{bare_metal_version.minor}"
return cuda_version

def get_cuda_version_of_torch(self,):
def get_cuda_version_of_torch(self):
torch_cuda_version = parse(torch.version.cuda)
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
return cuda_version
@@ -117,7 +132,7 @@ def get_torch_version(self,):
torch_version_raw = parse(torch.__version__)
torch_version = f"{torch_version_raw.major}{torch_version_raw.minor}"
return torch_version

def get_flash_version(self,):
version_file = os.path.join(
Path(VersionInfo.THIS_DIR), VersionInfo.PACKAGE_NAME, "__init__.py")
@@ -128,12 +143,21 @@ def get_flash_version(self,):
return flash_version

def get_package_version(self, full_version=False):
flash_version = self.get_flash_version()
package_version = f"{str(flash_version)}+cu{self.get_cuda_bare_metal_version(CUDA_HOME)}torch{self.get_torch_version()}{self.get_cpu_instruct()}"
flash_version = str(self.get_flash_version())
torch_version = self.get_torch_version()
cpu_instruct = self.get_cpu_instruct()
backend_version = ""
if CUDA_HOME is not None:
backend_version = f"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}"
elif MUSA_HOME is not None:
backend_version = f"mu{self.get_musa_bare_metal_version(MUSA_HOME)}"
else:
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
package_version = f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}"
if full_version:
return package_version
if not VersionInfo.FORCE_BUILD:
return str(flash_version)
return flash_version
return package_version
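
The rewritten `get_package_version` stamps wheels with a backend-qualified local version instead of assuming CUDA. A worked example of the resulting tags, with hypothetical inputs (real values come from nvcc/mcc, torch, and cpufeature):

```python
flash_version = "0.2.0"   # hypothetical ktransformers __init__ version
torch_version = "23"      # torch 2.3
cpu_instruct = "fancy"

for backend_version in ("cu121", "mu31"):  # CUDA 12.1 vs. MUSA 3.1 toolchains
    print(f"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}")
# 0.2.0+cu121torch23fancy
# 0.2.0+mu31torch23fancy
```
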


@@ -218,11 +242,19 @@ def build_extension(self, ext) -> None:
f"-DPYTHON_EXECUTABLE={sys.executable}",
f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm
]

if CUDA_HOME is not None:
cmake_args += ["-DKTRANSFORMERS_USE_CUDA=ON"]
elif MUSA_HOME is not None:
cmake_args += ["-DKTRANSFORMERS_USE_MUSA=ON"]
else:
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")

build_args = []
if "CMAKE_ARGS" in os.environ:
cmake_args += [
item for item in os.environ["CMAKE_ARGS"].split(" ") if item]

if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:
cpu_args = CpuInstructInfo.CMAKE_FANCY
elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:
@@ -231,7 +263,7 @@ def build_extension(self, ext) -> None:
cpu_args = CpuInstructInfo.CMAKE_AVX2
else:
cpu_args = CpuInstructInfo.CMAKE_NATIVE

cmake_args += [
item for item in cpu_args.split(" ") if item
]
@@ -288,28 +320,55 @@ def build_extension(self, ext) -> None:
print("Standard output:", result.stdout)
print("Standard error:", result.stderr)
subprocess.run(
["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
)

if CUDA_HOME is not None:
ops_module = CUDAExtension('KTransformersOps', [
'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
'ktransformers/ktransformers_ext/cuda/binding.cpp',
'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
],
extra_compile_args={
'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],
'nvcc': [
'-O3',
'--use_fast_math',
'-Xcompiler', '-fPIC',
'-DKTRANSFORMERS_USE_CUDA',
]
}
)
elif MUSA_HOME is not None:
SimplePorting(cuda_dir_path="ktransformers/ktransformers_ext/cuda", mapping_rule={
# Common rules
"at::cuda": "at::musa",
"#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
"#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
}).run()
ops_module = MUSAExtension('KTransformersOps', [
'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',
'ktransformers/ktransformers_ext/cuda_musa/binding.cpp',
# TODO: Add Marlin support for MUSA.
# 'ktransformers/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'
],
extra_compile_args={
'cxx': ['force_mcc'],
'mcc': [
'-O3',
'-DKTRANSFORMERS_USE_MUSA',
'-DTHRUST_IGNORE_CUB_VERSION_CHECK',
]
}
)
else:
raise ValueError("Unsupported backend: CUDA_HOME and MUSA_HOME are not set.")
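
For MUSA, `SimplePorting` first rewrites the CUDA sources into `ktransformers/ktransformers_ext/cuda_musa/` (the generated tree the new .gitignore entry excludes), applying the textual mapping above before `MUSAExtension` compiles the result. A sketch of that substitution on one line, assuming plain string replacement (torch_musa's actual implementation may be more involved):

```python
mapping_rule = {
    # Subset of the rules passed to SimplePorting above.
    "at::cuda": "at::musa",
    "#include <ATen/cuda/CUDAContext.h>":
        '#include "torch_musa/csrc/aten/musa/MUSAContext.h"',
}

def port_line(line: str) -> str:
    for cuda_text, musa_text in mapping_rule.items():
        line = line.replace(cuda_text, musa_text)
    return line

print(port_line("auto stream = at::cuda::getCurrentCUDAStream();"))
# -> auto stream = at::musa::getCurrentCUDAStream();
```
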

setup(
version=VersionInfo().get_package_version(),
cmdclass={"bdist_wheel":BuildWheelsCommand ,"build_ext": CMakeBuild},
ext_modules=[
CMakeExtension("cpuinfer_ext"),
CUDAExtension('KTransformersOps', [
'ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu',
'ktransformers/ktransformers_ext/cuda/binding.cpp',
'ktransformers/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'
],
extra_compile_args={
'cxx': ['-O3'],
'nvcc': [
'-O3',
'--use_fast_math',
'-Xcompiler', '-fPIC',
]
}
)
ops_module,
]
)
