@@ -168,14 +168,19 @@ struct vk_device_struct {
168168 uint32_t subgroup_size;
169169 uint32_t shader_core_count;
170170 bool uma;
171- bool coopmat2;
171+
172+ bool subgroup_size_control;
173+ uint32_t subgroup_min_size;
174+ uint32_t subgroup_max_size;
175+ bool subgroup_require_full_support;
172176
173177 bool coopmat_support;
174178 bool coopmat_acc_f32_support;
175179 bool coopmat_acc_f16_support;
176180 uint32_t coopmat_m;
177181 uint32_t coopmat_n;
178182 uint32_t coopmat_k;
183+ bool coopmat2;
179184
180185 size_t idx;
181186
@@ -753,8 +758,12 @@ static uint32_t compile_count = 0;
753758static std::mutex compile_count_mutex;
754759static std::condition_variable compile_count_cond;
755760
756- static void ggml_vk_create_pipeline_func (vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void * spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t , 3 > wg_denoms, std::vector<uint32_t > specialization_constants, uint32_t align, bool disable_robustness) {
757- VK_LOG_DEBUG (" ggml_vk_create_pipeline(" << device->name << " , " << name << " , " << entrypoint << " , " << parameter_count << " , " << push_constant_size << " , (" << wg_denoms[0 ] << " ," << wg_denoms[1 ] << " ," << wg_denoms[2 ] << " ), specialization_constants, " << align << " )" );
761+ static void ggml_vk_create_pipeline_func (vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void * spv_data, const std::string entrypoint,
762+ uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t , 3 > wg_denoms, std::vector<uint32_t > specialization_constants,
763+ uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
764+ VK_LOG_DEBUG (" ggml_vk_create_pipeline(" << device->name << " , " << name << " , " << entrypoint << " , " << parameter_count << " , " << push_constant_size <<
765+ " , (" << wg_denoms[0 ] << " ," << wg_denoms[1 ] << " ," << wg_denoms[2 ] << " ), specialization_constants, " << align <<
766+ " , " << disable_robustness << " , " << require_full_subgroups << " , " << required_subgroup_size << " )" );
758767 GGML_ASSERT (parameter_count > 0 );
759768 GGML_ASSERT (wg_denoms[0 ] > 0 && wg_denoms[1 ] > 0 && wg_denoms[2 ] > 0 ); // NOLINT
760769
@@ -813,14 +822,28 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
813822 specialization_constants.data ()
814823 );
815824
825+ vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{};
826+
827+ if (device->subgroup_require_full_support && require_full_subgroups) {
828+ pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT;
829+ }
830+
816831 vk::PipelineShaderStageCreateInfo pipeline_shader_create_info (
817- vk::PipelineShaderStageCreateFlags () ,
832+ pipeline_shader_stage_create_flags ,
818833 vk::ShaderStageFlagBits::eCompute,
819834 pipeline->shader_module ,
820835 entrypoint.c_str (),
821836 &specialization_info);
837+
838+ vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info;
839+ pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size;
840+ if (device->subgroup_size_control && required_subgroup_size > 0 ) {
841+ GGML_ASSERT (device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size );
842+ pipeline_shader_create_info.setPNext (&pipeline_shader_stage_required_subgroup_size_create_info);
843+ }
844+
822845 vk::ComputePipelineCreateInfo compute_pipeline_create_info (
823- vk::PipelineCreateFlags () ,
846+ vk::PipelineCreateFlags{} ,
824847 pipeline_shader_create_info,
825848 pipeline->layout );
826849
@@ -1500,7 +1523,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
15001523 device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
15011524
15021525 std::vector<std::future<void >> compiles;
1503- auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void * spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t , 3 > wg_denoms, const std::vector<uint32_t >& specialization_constants, uint32_t align, bool disable_robustness = false ) {
1526+ auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void * spv_data, const std::string &entrypoint,
1527+ uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t , 3 > wg_denoms, const std::vector<uint32_t >& specialization_constants,
1528+ uint32_t align, bool disable_robustness = false , bool require_full_subgroups = false , uint32_t required_subgroup_size = 0 ) {
15041529 {
15051530 // wait until fewer than N compiles are in progress
15061531 uint32_t N = std::max (1u , std::thread::hardware_concurrency ());
@@ -1510,7 +1535,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
15101535 }
15111536 compile_count++;
15121537 }
1513- compiles.push_back (std::async (ggml_vk_create_pipeline_func, std::ref (device), std::ref (pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness));
1538+ compiles.push_back (std::async (ggml_vk_create_pipeline_func, std::ref (device), std::ref (pipeline), name, spv_size, spv_data, entrypoint,
1539+ parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size));
15141540 };
15151541
15161542#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -1616,17 +1642,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
16161642 // Create 6 variants, {s,m,l}x{unaligned,aligned}
16171643#define CREATE_MM (PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID ) \
16181644 if (device->mul_mat ## ID ## _l) \
1619- ggml_vk_create_pipeline (device, device-> PIPELINE_NAME ->l , #NAMELC #F16ACC " _l" , NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, " main" , PARAMCOUNT, sizeof (PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1 ); \
1645+ ggml_vk_create_pipeline (device, device-> PIPELINE_NAME ->l , #NAMELC #F16ACC " _l" , NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, " main" , PARAMCOUNT, sizeof (PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1 , false , true ); \
16201646 if (device->mul_mat ## ID ## _m) \
1621- ggml_vk_create_pipeline (device, device-> PIPELINE_NAME ->m , #NAMELC #F16ACC " _m" , NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, " main" , PARAMCOUNT, sizeof (PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1 ); \
1647+ ggml_vk_create_pipeline (device, device-> PIPELINE_NAME ->m , #NAMELC #F16ACC " _m" , NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, " main" , PARAMCOUNT, sizeof (PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1 , false , true ); \
16221648 if (device->mul_mat ## ID ## _s) \
1623- ggml_vk_create_pipeline (device, device-> PIPELINE_NAME ->s , #NAMELC #F16ACC " _s" , NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, " main" , PARAMCOUNT, sizeof (PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1 ); \
1649+ ggml_vk_create_pipeline (device, device-> PIPELINE_NAME ->s , #NAMELC #F16ACC " _s" , NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, " main" , PARAMCOUNT, sizeof (PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1 , false , true ); \
16241650 if (device->mul_mat ## ID ## _l) \
1625- ggml_vk_create_pipeline (device, device-> PIPELINE_NAME ->a_l , #NAMELC #F16ACC " _aligned_l" , NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, " main" , PARAMCOUNT, sizeof (PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1651+ ggml_vk_create_pipeline (device, device-> PIPELINE_NAME ->a_l , #NAMELC #F16ACC " _aligned_l" , NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, " main" , PARAMCOUNT, sizeof (PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false , true ); \
16261652 if (device->mul_mat ## ID ## _m) \
1627- ggml_vk_create_pipeline (device, device-> PIPELINE_NAME ->a_m , #NAMELC #F16ACC " _aligned_m" , NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, " main" , PARAMCOUNT, sizeof (PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1653+ ggml_vk_create_pipeline (device, device-> PIPELINE_NAME ->a_m , #NAMELC #F16ACC " _aligned_m" , NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, " main" , PARAMCOUNT, sizeof (PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false , true ); \
16281654 if (device->mul_mat ## ID ## _s) \
1629- ggml_vk_create_pipeline (device, device-> PIPELINE_NAME ->a_s , #NAMELC #F16ACC " _aligned_s" , NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, " main" , PARAMCOUNT, sizeof (PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1655+ ggml_vk_create_pipeline (device, device-> PIPELINE_NAME ->a_s , #NAMELC #F16ACC " _aligned_s" , NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, " main" , PARAMCOUNT, sizeof (PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false , true ); \
16301656
16311657 // Create 2 variants, {f16,f32} accumulator
16321658#define CREATE_MM2 (PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID ) \
@@ -1993,6 +2019,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
19932019 amd_shader_core_properties2 = true ;
19942020 } else if (strcmp (" VK_EXT_pipeline_robustness" , properties.extensionName ) == 0 ) {
19952021 pipeline_robustness = true ;
2022+ } else if (strcmp (" VK_EXT_subgroup_size_control" , properties.extensionName ) == 0 ) {
2023+ device->subgroup_size_control = true ;
19962024 } else if (strcmp (" VK_KHR_cooperative_matrix" , properties.extensionName ) == 0 &&
19972025 !getenv (" GGML_VK_DISABLE_COOPMAT" )) {
19982026 device->coopmat_support = true ;
@@ -2012,6 +2040,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
20122040 vk::PhysicalDeviceDriverProperties driver_props;
20132041 vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
20142042 vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2043+ vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
2044+
20152045 props2.pNext = &props3;
20162046 props3.pNext = &subgroup_props;
20172047 subgroup_props.pNext = &driver_props;
@@ -2030,6 +2060,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
20302060 last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
20312061 last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
20322062 }
2063+ if (device->subgroup_size_control ) {
2064+ last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props;
2065+ last_struct = (VkBaseOutStructure *)&subgroup_size_control_props;
2066+ }
20332067
20342068#if defined(VK_NV_cooperative_matrix2)
20352069 vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
@@ -2067,11 +2101,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
20672101
20682102 device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
20692103
2070- if (device->vendor_id == VK_VENDOR_ID_INTEL || (props2. properties . vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2071- // Intel drivers don't support coopmat properly yet
2072- // Only RADV supports coopmat properly on AMD
2073- device->coopmat_support = false ;
2074- }
2104+ // if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2105+ // // Intel drivers don't support coopmat properly yet
2106+ // // Only RADV supports coopmat properly on AMD
2107+ // device->coopmat_support = false;
2108+ // }
20752109
20762110 std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device .getQueueFamilyProperties ();
20772111
@@ -2123,6 +2157,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
21232157 device_extensions.push_back (" VK_EXT_pipeline_robustness" );
21242158 }
21252159
2160+ VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;
2161+ subgroup_size_control_features.pNext = nullptr ;
2162+ subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;
2163+ subgroup_size_control_features.computeFullSubgroups = false ;
2164+ subgroup_size_control_features.subgroupSizeControl = false ;
2165+
2166+ if (device->subgroup_size_control ) {
2167+ last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features;
2168+ last_struct = (VkBaseOutStructure *)&subgroup_size_control_features;
2169+ }
2170+
21262171 VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
21272172 coopmat_features.pNext = nullptr ;
21282173 coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
@@ -2150,6 +2195,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
21502195
21512196 device->pipeline_robustness = pl_robustness_features.pipelineRobustness ;
21522197
2198+ device->subgroup_size_control = device->subgroup_size_control &&
2199+ (!(subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) ||
2200+ !subgroup_size_control_features.subgroupSizeControl );
2201+
2202+ if (device->subgroup_size_control ) {
2203+ device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize ;
2204+ device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize ;
2205+ device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups ;
2206+ device_extensions.push_back (" VK_EXT_subgroup_size_control" );
2207+ }
2208+
21532209 device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix ;
21542210
21552211 if (coopmat2_support) {
@@ -2430,11 +2486,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
24302486 }
24312487 }
24322488
2433- if (props2.properties .vendorID == VK_VENDOR_ID_INTEL || (props2.properties .vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2434- // Intel drivers don't support coopmat properly yet
2435- // Only RADV supports coopmat properly on AMD
2436- coopmat_support = false ;
2437- }
2489+ // if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2490+ // // Intel drivers don't support coopmat properly yet
2491+ // // Only RADV supports coopmat properly on AMD
2492+ // coopmat_support = false;
2493+ // }
24382494
24392495 const char * GGML_VK_DISABLE_F16 = getenv (" GGML_VK_DISABLE_F16" );
24402496 bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr ;
0 commit comments