[INFER] Add cutlass fp8 gemm auto tune by Sunny-bot1 · Pull Request #9020 · PaddlePaddle/PaddleNLP · GitHub

[INFER] Add cutlass fp8 gemm auto tune #9020

Status: Closed · Sunny-bot1 wants to merge 20 commits
583 changes: 583 additions & 0 deletions csrc/generate_code_dual_gemm_fused_kernels.py

Large diffs are not rendered by default.

36 changes: 18 additions & 18 deletions csrc/generate_code_gemm_fused_kernels.py
@@ -99,15 +99,15 @@ def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):

COMMON_DECLARE_string(use_cutlass_device_best_config_path);

-std::map<std::string, int> config_map{"""
+std::map<std::string, int> gemm_type_map{"""

code_part1 = """
{"{input_type}_{output_type}_{hasbias}_{act_tag}", {type_id}}, """

code_part2 = """
};

-std::map<std::string, int> gemm_configs_map{
+std::map<std::string, int> gemm_config_map{
"""

code_part3 = """ {"{thread_block_shape}, {warp_shape}, {mma_shape}, {num_stages}", {tile_id}},
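For orientation: the generator stitches these fragments into two lookup tables in the emitted C++ source, `gemm_type_map` keyed by the fused-gemm type string and `gemm_config_map` keyed by the tile/stage configuration string. A minimal sketch of what the generated maps might look like, with illustrative entries that are not taken from the PR:

#include <map>
#include <string>

// Hypothetical excerpt of the generated source (entries illustrative).
// gemm_type_map keys follow "{input_type}_{output_type}_{hasbias}_{act_tag}":
std::map<std::string, int> gemm_type_map{
    {"e4m3_bf16_true_noact", 0},
    {"e4m3_bf16_false_relu", 1},
};
// gemm_config_map keys follow "{thread_block_shape}, {warp_shape}, {mma_shape}, {num_stages}":
std::map<std::string, int> gemm_config_map{
    {"<64, 64, 128>, <32, 32, 128>, <16, 8, 32>, 3", 0},
    {"<128, 64, 128>, <64, 32, 128>, <16, 8, 32>, 4", 1},
};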
@@ -165,40 +165,40 @@ def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):


bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) {
-if (config_map.find(params.fuse_gemm_config) == config_map.end()) {
+if (gemm_type_map.find(params.fuse_gemm_config) == gemm_type_map.end()) {
throw std::runtime_error("fp8 gemm_fused config is invalid.");
}

-int type_id = config_map[params.fuse_gemm_config];
+int type_id = gemm_type_map[params.fuse_gemm_config];
int M = (params.M+31)/32 *32;
int N = params.N;
int K = params.K;

-std::string mkn_string = "<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">";
-std::string mkn_split_k_string = "<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">" + ", split_k";
+std::string mnk_string = "gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">";
+std::string mnk_split_k_string = "gemm<"+ std::to_string(M)+ ", " +std::to_string(N) + ", "+ std::to_string(K)+ ">" + ", split_k";
int split_k;
int kernel_id;
std::string best_config;
CutlassGemmConfigMannager& best_config_mannager = CutlassGemmConfigMannager::getInstance();
if(getenv("FLAGS_use_cutlass_device_best_config_path")){ // run kernel
std::string config_file_path = getenv("FLAGS_use_cutlass_device_best_config_path");
nlohmann::json* config_json = best_config_mannager.get_gemm_best_configs(config_file_path);
-if (config_json->contains(mkn_string)) {
-best_config = config_json->at(mkn_string);
+if (config_json->contains(mnk_string)) {
+best_config = config_json->at(mnk_string);
} else {
std::cerr << "Can not find the config for this gemm shape, please tune this shape: " << mkn_string <<std::endl;
std::cerr << "Can not find the config for this gemm shape, please tune this shape: " << mnk_string <<std::endl;
}

-if (config_json->contains(mkn_split_k_string)) {
-split_k = config_json->at(mkn_split_k_string);
+if (config_json->contains(mnk_split_k_string)) {
+split_k = config_json->at(mnk_split_k_string);
} else {
std::cerr << "Can not find the config(split_k) for this gemm shape, please tune this shape: " << mkn_string <<std::endl;
std::cerr 8000 << "Can not find the config(split_k) for this gemm shape, please tune this shape: " << mnk_string <<std::endl;
}

-if (gemm_configs_map.find(best_config) == gemm_configs_map.end()) {
+if (gemm_config_map.find(best_config) == gemm_config_map.end()) {
throw std::runtime_error("This config'kernel not be generate, please check generate_code_gemm_fused_kernels.py and re-generate.");
} else {
-kernel_id = gemm_configs_map[best_config];
+kernel_id = gemm_config_map[best_config];
}
return launch_gemm_kernel(type_id, split_k, kernel_id, params);
} else { // tune kernel
@@ -209,7 +209,7 @@ def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
float duratation = 1000000.f;
// tune all split_k, kernel_id kernels
for(int i = 1; i < {max_split_k}+1; ++i){ // all split_k
-for(const auto& config_pair : gemm_configs_map){
+for(const auto& config_pair : gemm_config_map){
bool is_valid = true;
// warm up
for(int num_time = 0; num_time < warm_up_times; ++num_time){
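The timing body of this loop is folded out of the view. For reference, a generic CUDA-event pattern for benchmarking one (split_k, kernel_id) candidate after the warm-up looks roughly like the sketch below; this is not the PR's code, and `run_candidate`, `repeat_times`, and `stream` are hypothetical stand-ins for the dispatched launch, the repeat count, and the stream carried in `params`:

// Sketch only: time one candidate with CUDA events after warming up.
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaEventRecord(start, stream);
for (int t = 0; t < repeat_times; ++t) {
  run_candidate(i, config_pair.second, params);  // hypothetical dispatch
}
cudaEventRecord(stop, stream);
cudaEventSynchronize(stop);
float elapsed_ms = 0.0f;
cudaEventElapsedTime(&elapsed_ms, start, stop);
if (elapsed_ms < duratation) {  // 'duratation' tracks the best time so far
  duratation = elapsed_ms;
  best_split_k = i;
  best_kernel_id = config_pair.second;  // the real code may record the config key instead
}
cudaEventDestroy(start);
cudaEventDestroy(stop);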
@@ -251,10 +251,10 @@ def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
}

nlohmann::json new_json;
-new_json[mkn_string] = best_kernel_id;
-new_json[mkn_split_k_string] = best_split_k;
+new_json[mnk_string] = best_kernel_id;
+new_json[mnk_split_k_string] = best_split_k;
best_config_mannager.up_date_configs(new_json);
std::cout <<"Gemm tune result for " << mkn_string<< ": best config is: "<< best_kernel_id << ", split k: " << best_split_k << std::endl;
std::cout <<"Gemm tune result for " << mnk_string<< ": best config is: "<< best_kernel_id << ", split k: " << best_split_k << std::endl;
return true;
}
}
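Putting the two paths together: the tuning sweep writes the winning entries into a JSON file keyed by the `gemm<M, N, K>` strings built above, and later runs read it back via `FLAGS_use_cutlass_device_best_config_path`. Inferred from how the run path reads the file (the per-shape entry is looked up in `gemm_config_map`, the split_k entry is an integer), a hypothetical config file with illustrative values would look like:

{
  "gemm<32, 4096, 4096>": "<64, 64, 128>, <32, 32, 128>, <16, 8, 32>, 3",
  "gemm<32, 4096, 4096>, split_k": 2
}

With the flag pointing at such a file, fp8_fp8_gemm_scale_bias_act takes the run path and dispatches directly; with it unset, the tuning sweep runs and writes its results back through CutlassGemmConfigMannager::up_date_configs.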

This file was deleted.

This file was deleted.
