Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
metascroy committed Feb 5, 2025
1 parent a1572fd commit 7a43be4
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 228 deletions.
305 changes: 189 additions & 116 deletions torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#pragma once

#include <cassert>
#include <optional>
#include <stdexcept>
#include <unordered_map>
#include <utility>

#include <cpuinfo.h>

#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h>

#if defined(__aarch64__) || defined(__ARM_NEON)
Expand All @@ -28,105 +27,200 @@

namespace torchao::ops::linear_8bit_act_xbit_weight {

namespace {
using UKernelConfigCacheKey = torchao::ops::PackedWeightsFormat;
using UKernelConfigCacheType = std::unordered_map<UKernelConfigCacheKey, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig>;
}
struct UniversalPackedWeightsFormat {
int version;
int weight_nbit;
bool has_weight_zeros;
bool has_bias;
int nr;
int kr;

static UniversalPackedWeightsFormat from_packed_weights_format(torchao::ops::PackedWeightsFormat format) {
if (format.type != torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal) {
throw std::runtime_error("Packed weights are not in universal packing format.");
}
return UniversalPackedWeightsFormat{
format.params[0],
format.params[1],
static_cast<bool>(format.params[2]),
static_cast<bool>(format.params[3]),
format.params[4],
format.params[5],
};
}
inline torchao::ops::PackedWeightsFormat to_packed_weights_format() const {
return torchao::ops::PackedWeightsFormat(
torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal,
{
version,
weight_nbit,
has_weight_zeros,
has_bias,
nr,
kr
});
}
};

struct KleidiAIPackedWeightsFormat {
int weight_nbit;
bool has_weight_zeros;
bool has_bias;
int nr;
int kr;
int sr;

static KleidiAIPackedWeightsFormat from_packed_weights_format(torchao::ops::PackedWeightsFormat format) {
if (format.type != torchao::ops::PackedWeightsType::kleidi_ai) {
throw std::runtime_error("Packed weights are not in kleidi_ai packing format.");
}
return KleidiAIPackedWeightsFormat{
format.params[0],
static_cast<bool>(format.params[1]),
static_cast<bool>(format.params[2]),
format.params[3],
format.params[4],
format.params[5]
};
}
inline torchao::ops::PackedWeightsFormat to_packed_weights_format() const {
return torchao::ops::PackedWeightsFormat(
torchao::ops::PackedWeightsType::kleidi_ai,
{weight_nbit,
has_weight_zeros,
has_bias,
nr,
kr,
sr});
}
};

struct UKernelConfigRegistrationTable {
private:
std::unordered_map<torchao::ops::PackedWeightsFormat, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig> registration_table_;
public:
void register_ukernel_config(torchao::ops::PackedWeightsFormat format, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config) {
if (registration_table_.find(format) != registration_table_.end()) {
throw std::runtime_error("UKernelConfig is already registered for this format");
}
registration_table_[format] = config;
}
std::optional<torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig> get_ukernel_config(torchao::ops::PackedWeightsFormat format) const {
auto it = registration_table_.find(format);
if (it == registration_table_.end()) {
return std::nullopt;
}
return it->second;
}
};

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
void register_ukernel_config_universal(UKernelConfigCacheType& ukernel_config_cache, int nr, int kr, int version) {
void register_ukernel_config_universal(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) {
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
UKernelConfigCacheKey key = torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_universal(weight_nbit, has_weight_zeros, has_bias, nr, kr);
auto universal_format = UniversalPackedWeightsFormat::from_packed_weights_format(format);
if (universal_format.weight_nbit != weight_nbit) {
throw std::runtime_error("Packed weights are not in the expected format");
}
if (universal_format.has_weight_zeros != has_weight_zeros) {
throw std::runtime_error("Packed weights are not in the expected format");
}
if (universal_format.has_bias != has_bias) {
throw std::runtime_error("Packed weights are not in the expected format");
}

if (cpuinfo_has_arm_neon_dot()) {
if (nr == 8 && kr == 16) {
if (universal_format.nr == 8 && universal_format.kr == 16) {
#if defined(__aarch64__) || defined(__ARM_NEON)
if (cpuinfo_has_arm_neon_dot()) {
namespace kernel = torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
/*preferred_alignment*/16,
/*weight_packing*/
{
/*nr*/8,
/*weight_data_size_fn*/&kernel::weight_data_size<weight_nbit, has_weight_zeros, has_bias>,
/*prepare_weight_data_fn*/&kernel::prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>
},
/*kernels*/
{{
{
/*mr*/1,
/*activation_data_size_fn*/&kernel::activation_data_size<has_weight_zeros>,
/*prepare_activation_data_fn*/&kernel::prepare_activation_data<has_weight_zeros>,
/*kernel*/&kernel::kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>
table.register_ukernel_config(
format,
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
/*preferred_alignment*/16,
/*weight_packing*/
{
/*nr*/8,
/*weight_data_size_fn*/&kernel::weight_data_size<weight_nbit, has_weight_zeros, has_bias>,
/*prepare_weight_data_fn*/&kernel::prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>
},
/*kernels*/
{{
{
/*mr*/1,
/*activation_data_size_fn*/&kernel::activation_data_size<has_weight_zeros>,
/*prepare_activation_data_fn*/&kernel::prepare_activation_data<has_weight_zeros>,
/*kernel*/&kernel::kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>
}
}}
}
);
return;
}
}}
};
return;
}
#endif // defined(__aarch64__) || defined(__ARM_NEON)
}

throw std::runtime_error("Cannot register ukernel_config for packing format ukernel because no implementation is available on this platform");
}

template <int weight_nbit, bool has_weight_zeros, bool has_bias>
void register_ukernel_config_kleidi_ai(UKernelConfigCacheType& ukernel_config_cache, int nr, int kr, int sr) {
void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) {
std::cout << "register_ukernel_config_kleidi_ai" << std::endl;
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}

// TODO: make better
UKernelConfigCacheKey key = torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_kleidi_ai(weight_nbit, has_weight_zeros, has_bias, nr, kr, sr);
auto kleidi_ai_format = KleidiAIPackedWeightsFormat::from_packed_weights_format(format);
int nr = kleidi_ai_format.nr;
int kr = kleidi_ai_format.kr;
int sr = kleidi_ai_format.sr;

#if defined (TORCHAO_ENABLE_ARM_I8MM)
if (cpuinfo_has_arm_i8mm()) {
if (nr == 8 && kr == 16 && sr == 2) {
if (nr == 8 && kr == 16 && sr == 2) {
#if defined (TORCHAO_ENABLE_ARM_I8MM)
if (cpuinfo_has_arm_i8mm()) {
namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32;
auto uk = kernel::get_ukernel();
assert (nr == uk.get_nr());
assert (kr == uk.get_kr());
assert (sr == uk.get_sr());

ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
/*preferred_alignment*/16,
/*weight_packing*/
{
/*nr*/static_cast<int>(uk.get_n_step()),
/*weight_data_size_fn*/&kernel::weight_data_size,
/*prepare_weight_data_fn*/&kernel::prepare_weight_data
},
/*kernels*/
{{
{
/*mr*/static_cast<int>(uk.get_m_step()),
/*activation_data_size_fn*/&kernel::activation_data_size,
/*prepare_activation_data_fn*/&kernel::prepare_activation_data,
/*kernel*/&kernel::kernel
table.register_ukernel_config(
format,
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
/*preferred_alignment*/16,
/*weight_packing*/
{
/*nr*/static_cast<int>(uk.get_n_step()),
/*weight_data_size_fn*/&kernel::weight_data_size,
/*prepare_weight_data_fn*/&kernel::prepare_weight_data
},
/*kernels*/
{{
{
/*mr*/static_cast<int>(uk.get_m_step()),
/*activation_data_size_fn*/&kernel::activation_data_size,
/*prepare_activation_data_fn*/&kernel::prepare_activation_data,
/*kernel*/&kernel::kernel
}
}}
}
}}
};
);
return;
}
return;
}
#endif // TORCHAO_ENABLE_ARM_I8MM


if (cpuinfo_has_arm_neon_dot()) {
if (nr == 8 && kr == 16 && sr == 2) {
namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;
auto uk = kernel::get_ukernel();
assert (nr == uk.get_nr());
assert (kr == uk.get_kr());
assert (sr == uk.get_sr());
}
#endif // TORCHAO_ENABLE_ARM_I8MM

ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
if (cpuinfo_has_arm_neon_dot()) {
namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;
auto uk = kernel::get_ukernel();
assert (nr == uk.get_nr());
assert (kr == uk.get_kr());
assert (sr == uk.get_sr());
table.register_ukernel_config(
format,
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
/*preferred_alignment*/16,
/*weight_packing*/
{
/*nr*/static_cast<int>(uk.get_n_step()),
/*weight_data_size_fn*/&kernel::weight_data_size,
/*prepare_weight_data_fn*/&kernel::prepare_weight_data
/*nr*/static_cast<int>(uk.get_n_step()),
/*weight_data_size_fn*/&kernel::weight_data_size,
/*prepare_weight_data_fn*/&kernel::prepare_weight_data
},
/*kernels*/
{{
Expand All @@ -136,79 +230,58 @@ void register_ukernel_config_kleidi_ai(UKernelConfigCacheType& ukernel_config_ca
/*prepare_activation_data_fn*/&kernel::prepare_activation_data,
/*kernel*/&kernel::kernel
}
}}
};
return;
}

if (nr == 4 && kr == 8 && sr == 2) {
// TODO
return;
}}
}
);
return;
}
}


throw std::runtime_error("Cannot register ukernel_config for packing format kleidi_ai because no implementation is available on this platform");
}


template <int weight_nbit, bool has_weight_zeros>
void register_ukernel_config(UKernelConfigCacheType& ukernel_config_cache, torchao::ops::PackedWeightsFormat format) {
auto it = ukernel_config_cache.find(format);
if (it != ukernel_config_cache.end()) {
throw std::runtime_error("UKernel config already registered");
}

void register_ukernel_config(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) {
switch (format.type) {
case torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal: {
auto packing_params = torchao::ops::linear_8bit_act_xbit_weight::get_universal_packing_params(format);
if (packing_params.weight_nbit != weight_nbit) {
throw std::runtime_error("Packed weights are not in the expected format");
}
if (packing_params.has_weight_zeros != has_weight_zeros) {
throw std::runtime_error("Packed weights are not in the expected format");
}
if (packing_params.has_bias) {
register_ukernel_config_universal<weight_nbit, has_weight_zeros, /*has_bias*/ true, /*has_clamp*/false>(ukernel_config_cache, packing_params.nr, packing_params.kr, packing_params.version);
auto universal_format = UniversalPackedWeightsFormat::from_packed_weights_format(format);
if (universal_format.has_bias) {
register_ukernel_config_universal<weight_nbit, has_weight_zeros, /*has_bias*/ true, /*has_clamp*/false>(table, format);
} else {
register_ukernel_config_universal<weight_nbit, has_weight_zeros, /*has_bias*/ false, /*has_clamp*/false>(ukernel_config_cache, packing_params.nr, packing_params.kr, packing_params.version);
register_ukernel_config_universal<weight_nbit, has_weight_zeros, /*has_bias*/ false, /*has_clamp*/false>(table, format);
}
break;
}
case torchao::ops::PackedWeightsType::kleidi_ai: {
auto packing_params = torchao::ops::linear_8bit_act_xbit_weight::get_kleidi_ai_packing_params(format);
assert (packing_params.has_bias == true);
register_ukernel_config_kleidi_ai<weight_nbit, has_weight_zeros, /*has_bias*/true>(ukernel_config_cache, packing_params.nr, packing_params.kr, packing_params.sr);
register_ukernel_config_kleidi_ai<weight_nbit, has_weight_zeros, /*has_bias*/true>(table, format);
break;
}
default:
throw std::runtime_error("No implementation for packed weights format");
}

it = ukernel_config_cache.find(format);
if (it == ukernel_config_cache.end()) {
auto config = table.get_ukernel_config(format);
if (!config.has_value()) {
throw std::runtime_error("UKernel config did not register");
}
}


// Returns the ukernel config for `format`, registering it on first use.
// Results are memoized in a function-local static table, so registration
// work (cpuinfo queries, kernel selection) happens once per format.
//
// NOTE(review): removed stray debug `std::cout` statements left in this
// library function. The static table's initialization is thread-safe, but
// the table itself is not synchronized — concurrent first calls with the
// same format could race in register_ukernel_config; confirm callers
// serialize this, or add a mutex.
template <int weight_nbit, bool has_weight_zeros>
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsFormat format) {
  static UKernelConfigRegistrationTable table;

  // Fast path: already registered.
  auto ukernel = table.get_ukernel_config(format);
  if (ukernel.has_value()) {
    return ukernel.value();
  }

  // Slow path: register, then re-read (register_ukernel_config throws if
  // registration did not take effect, so the assert documents an invariant).
  register_ukernel_config<weight_nbit, has_weight_zeros>(table, format);

  ukernel = table.get_ukernel_config(format);
  assert(ukernel.has_value());
  return ukernel.value();
}

// TODO: make packing format and format separate concepts
Expand All @@ -223,15 +296,15 @@ torchao::ops::PackedWeightsFormat select_packed_weights_format(std::optional<std
#if defined(TORCHAO_ENABLE_KLEIDI)
if (!target || *target == "kleidi_ai") {
if (weight_nbit == 4 && !has_weight_zeros) {
return torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_kleidi_ai(weight_nbit, has_weight_zeros, /*has_bias*/true, /*nr*/8, /*kr*/16, /*sr*/2);
return KleidiAIPackedWeightsFormat({weight_nbit, has_weight_zeros, /*has_bias*/true, /*nr*/8, /*kr*/16, /*sr*/2}).to_packed_weights_format();
}
}
#endif // defined(TORCHAO_ENABLE_KLEIDI)

// Select universal format
if (!target || *target == "universal") {
if (cpuinfo_has_arm_neon_dot()) {
return torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_universal(weight_nbit, has_weight_zeros, has_bias, /*nr*/8, /*kr*/16, /*version*/1);
return UniversalPackedWeightsFormat({/*version*/1, weight_nbit, has_weight_zeros, has_bias, /*nr*/8, /*kr*/16}).to_packed_weights_format();
}
}

Expand Down
Loading

0 comments on commit 7a43be4

Please sign in to comment.