diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl
index 821f7f79b0e..5fc58f03cd1 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl
@@ -26,7 +26,6 @@ layout(std430) buffer;
 
 #include "indexing.glslh"
 #include "common.glslh"
-#include "conv2d_common.glslh"
 
 ${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=True)}
 ${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=True)}
@@ -38,7 +37,6 @@ ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False
 // Metadata for input/output tensors (memory layout agnostic)
 ${layout_declare_ubo(B, "BufferMetadata", "outp")}
 ${layout_declare_ubo(B, "BufferMetadata", "inp")}
-${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
 
 layout(push_constant) uniform restrict Block {
   float input_scale;
@@ -56,6 +54,30 @@ ${layout_declare_spec_const(C, "int", "activation_type", "0")}
 ${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
 ${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")}
 
+$if USE_SPEC_CONST:
+  // Conv2D parameter specialization constants
+  ${layout_declare_spec_const(C, "int", "kernel_size_x", "1")}
+  ${layout_declare_spec_const(C, "int", "kernel_size_y", "1")}
+  ${layout_declare_spec_const(C, "int", "stride_x", "1")}
+  ${layout_declare_spec_const(C, "int", "stride_y", "1")}
+  ${layout_declare_spec_const(C, "int", "padding_x", "0")}
+  ${layout_declare_spec_const(C, "int", "padding_y", "0")}
+  ${layout_declare_spec_const(C, "int", "dilation_x", "1")}
+  ${layout_declare_spec_const(C, "int", "dilation_y", "1")}
+  ${layout_declare_spec_const(C, "int", "groups", "1")}
+$else:
+  #include "conv2d_common.glslh"
+  ${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
+  #define kernel_size_x conv2d_params.kernel_size.x
+  #define kernel_size_y conv2d_params.kernel_size.y
+  #define stride_x conv2d_params.stride.x
+  #define stride_y conv2d_params.stride.y
+  #define padding_x conv2d_params.padding.x
+  #define padding_y conv2d_params.padding.y
+  #define dilation_x conv2d_params.dilation.x
+  #define dilation_y conv2d_params.dilation.y
+  #define groups conv2d_params.groups
+
 // Load weight block for a given (ic4, kx, ky, oc4) position.
 // Weight texture layout (from pack_q8_conv2d_weights.glsl):
 // block_x = oc4 * K_w + kx
@@ -101,8 +123,8 @@ void main() {
   const int IC = int(inp.sizes[0][2]);
 
   // Compute channels per group
-  const int OC_per_group = OC / conv2d_params.groups;
-  const int IC_per_group = IC / conv2d_params.groups;
+  const int OC_per_group = OC / groups;
+  const int IC_per_group = IC / groups;
   const int IC4_per_group = div_up_4(IC_per_group);
 
   // Determine which group this output channel block belongs to
@@ -113,14 +135,14 @@ void main() {
   const int inp_w_stride = int(inp.strides[0][0]);
   const int inp_h_stride = int(inp.strides[0][1]);
   const int inp_c_stride = int(inp.strides[0][2]);
-  const int w_texel_step = conv2d_params.dilation.x * inp_w_stride;
-  const int h_texel_step = conv2d_params.dilation.y * inp_h_stride;
-  const int subtile_w_step = conv2d_params.stride.x * inp_w_stride;
+  const int w_texel_step = dilation_x * inp_w_stride;
+  const int h_texel_step = dilation_y * inp_h_stride;
+  const int subtile_w_step = stride_x * inp_w_stride;
 
   // Compute base input position (for subtile_w=0, ic4=0)
   TensorIndex4D inp_tidx;
-  inp_tidx.data[0] = outp_tidx.data[0] * conv2d_params.stride.x - conv2d_params.padding.x;
-  inp_tidx.data[1] = outp_tidx.data[1] * conv2d_params.stride.y - conv2d_params.padding.y;
+  inp_tidx.data[0] = outp_tidx.data[0] * stride_x - padding_x;
+  inp_tidx.data[1] = outp_tidx.data[1] * stride_y - padding_y;
   inp_tidx.data[2] = ic_group_start;
   inp_tidx.data[3] = 0;
@@ -142,7 +164,7 @@ void main() {
   }
 
   // Perform convolution using packed int8 dot products
-  for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) {
+  for (int ky = 0; ky < kernel_size_y; ky++) {
     const bool h_in_bounds = (inp_tidx.data[1] >= 0 && inp_tidx.data[1] < inp_H);
 
     // Process input channels in blocks of 4
@@ -153,10 +175,10 @@ void main() {
       // Reset width coordinate at start of each ic4 iteration
       inp_tidx.data[0] = base_inp_w;
 
-      for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) {
+      for (int kx = 0; kx < kernel_size_x; kx++) {
        // Load weight block: 4 output channels × 4 input channels
        // weight_block[oc] contains packed weights for ic4*4 to ic4*4+3 -> oc
-        const ivec4 weight_block = load_weight_block(ic4, kx, ky, oc4, IC4_per_group, conv2d_params.kernel_size.x);
+        const ivec4 weight_block = load_weight_block(ic4, kx, ky, oc4, IC4_per_group, kernel_size_x);
 
        // Process 4 adjacent width positions
        [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
@@ -187,16 +209,16 @@ void main() {
         }
 
         // Advance to next output position's input coordinate
-        inp_tidx.data[0] += conv2d_params.stride.x;
+        inp_tidx.data[0] += stride_x;
       }
 
       // Adjust for net dilation step
-      inp_tidx.data[0] += conv2d_params.dilation.x - 4 * conv2d_params.stride.x;
+      inp_tidx.data[0] += dilation_x - 4 * stride_x;
     }
   }
 
   // Advance height by dilation for next kernel row
-  inp_tidx.data[1] += conv2d_params.dilation.y;
+  inp_tidx.data[1] += dilation_y;
 
   if (get_outer_packed_dim_block_size(inp_layout) == 1) {
     // Advance base index by height step for next kernel row
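A note on the mechanism: each `layout_declare_spec_const` lowers to a Vulkan specialization constant (`layout(constant_id = N) const int ...`), so the host must supply values in exactly the declaration order when the compute pipeline is created. Below is a minimal sketch of the raw Vulkan plumbing that the `SpecVarList` abstraction generates under the hood; the starting constant ID and the helper name are illustrative assumptions, not part of this diff.

```cpp
#include <vulkan/vulkan.h>
#include <array>
#include <cstdint>

// Builds the VkSpecializationInfo for the nine conv2d parameters declared
// above (kernel_size_x .. groups). `entries` must outlive pipeline creation.
// kFirstConvParamId is hypothetical: in the real shader, earlier IDs are
// taken by the workgroup size and the apply_bias/activation/layout constants.
constexpr uint32_t kFirstConvParamId = 7;

VkSpecializationInfo make_conv2d_spec_info(
    const std::array<int32_t, 9>& params,
    std::array<VkSpecializationMapEntry, 9>& entries) {
  for (uint32_t i = 0; i < entries.size(); ++i) {
    entries[i].constantID = kFirstConvParamId + i;
    entries[i].offset = i * static_cast<uint32_t>(sizeof(int32_t));
    entries[i].size = sizeof(int32_t);
  }
  VkSpecializationInfo info{};
  info.mapEntryCount = static_cast<uint32_t>(entries.size());
  info.pMapEntries = entries.data();
  info.dataSize = params.size() * sizeof(int32_t);
  info.pData = params.data();
  // Pass via VkComputePipelineCreateInfo::stage.pSpecializationInfo.
  return info;
}
```

Because the values are baked in at pipeline-compile time, the driver can fully unroll the `ky`/`kx` loops and constant-fold the stride/padding/dilation arithmetic; the cost is one compiled pipeline per distinct convolution configuration.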
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.yaml
index 6ced1c16ebb..316e2f38fc2 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.yaml
@@ -8,6 +8,7 @@ q8ta_conv2d:
   parameter_names_with_default_values:
     DTYPE: float
     USE_INT8_DOT_PRODUCT_EXT: 1
+    USE_SPEC_CONST: 0
   generate_variant_forall:
     DTYPE:
       - VALUE: float
@@ -15,3 +16,8 @@ q8ta_conv2d:
   - NAME: q8ta_conv2d
   - NAME: q8ta_conv2d_fallback
     USE_INT8_DOT_PRODUCT_EXT: 0
+  - NAME: q8ta_conv2d_spec_const
+    USE_SPEC_CONST: 1
+  - NAME: q8ta_conv2d_fallback_spec_const
+    USE_INT8_DOT_PRODUCT_EXT: 0
+    USE_SPEC_CONST: 1
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl
index 7f4d03887df..e2d63478e97 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl
@@ -20,7 +20,6 @@ layout(std430) buffer;
 
 #include "indexing.glslh"
 #include "common.glslh"
-#include "conv2d_common.glslh"
 
 ${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=True)}
 ${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=True)}
@@ -32,7 +31,6 @@ ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False
 // Metadata for input/output tensors (memory layout agnostic)
 ${layout_declare_ubo(B, "BufferMetadata", "outp")}
 ${layout_declare_ubo(B, "BufferMetadata", "inp")}
-${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
 
 layout(push_constant) uniform restrict Block {
   float input_scale;
@@ -50,6 +48,28 @@ ${layout_declare_spec_const(C, "int", "activation_type", "0")}
 ${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
 ${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")}
 
+$if USE_SPEC_CONST:
+  // Conv2D parameter specialization constants
+  ${layout_declare_spec_const(C, "int", "kernel_size_x", "1")}
+  ${layout_declare_spec_const(C, "int", "kernel_size_y", "1")}
+  ${layout_declare_spec_const(C, "int", "stride_x", "1")}
+  ${layout_declare_spec_const(C, "int", "stride_y", "1")}
+  ${layout_declare_spec_const(C, "int", "padding_x", "0")}
+  ${layout_declare_spec_const(C, "int", "padding_y", "0")}
+  ${layout_declare_spec_const(C, "int", "dilation_x", "1")}
+  ${layout_declare_spec_const(C, "int", "dilation_y", "1")}
+$else:
+  #include "conv2d_common.glslh"
+  ${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
+  #define kernel_size_x conv2d_params.kernel_size.x
+  #define kernel_size_y conv2d_params.kernel_size.y
+  #define stride_x conv2d_params.stride.x
+  #define stride_y conv2d_params.stride.y
+  #define padding_x conv2d_params.padding.x
+  #define padding_y conv2d_params.padding.y
+  #define dilation_x conv2d_params.dilation.x
+  #define dilation_y conv2d_params.dilation.y
+
 #include "block_indexing.glslh"
 
 // Load a 4xint8 block of weights.
@@ -89,22 +109,22 @@ void main() {
   }
 
   // Compute weight addressing constants
-  const int KW4 = int(div_up_4(conv2d_params.kernel_size.x));
+  const int KW4 = int(div_up_4(kernel_size_x));
 
   // Get strides for width and height dimensions (in texel space)
   const int w_stride = int(inp.strides[0][0]);
   const int h_stride = int(inp.strides[0][1]);
 
   // Pre-compute step sizes for efficient indexing
-  const int w_texel_step = conv2d_params.dilation.x * w_stride;
-  const int h_texel_step = conv2d_params.dilation.y * h_stride;
+  const int w_texel_step = dilation_x * w_stride;
+  const int h_texel_step = dilation_y * h_stride;
   // Step between adjacent output width positions in input texel space
-  const int subtile_w_step = conv2d_params.stride.x * w_stride;
+  const int subtile_w_step = stride_x * w_stride;
 
   // Compute base input position for subtile_w=0
   TensorIndex4D inp_tidx;
-  inp_tidx.data[0] = outp_tidx.data[0] * conv2d_params.stride.x - conv2d_params.padding.x;
-  inp_tidx.data[1] = outp_tidx.data[1] * conv2d_params.stride.y - conv2d_params.padding.y;
+  inp_tidx.data[0] = outp_tidx.data[0] * stride_x - padding_x;
+  inp_tidx.data[1] = outp_tidx.data[1] * stride_y - padding_y;
   inp_tidx.data[2] = outp_tidx.data[2];
   inp_tidx.data[3] = 0; // batch = 0 since N == 1
@@ -128,13 +148,13 @@ void main() {
   const int inp_H = int(inp.sizes[0][1]);
 
   // Perform depthwise convolution
-  for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) {
+  for (int ky = 0; ky < kernel_size_y; ky++) {
     const bool h_in_bounds = (inp_tidx.data[1] >= 0 && inp_tidx.data[1] < inp_H);
 
     // Reset width coordinate at start of each kernel row
     inp_tidx.data[0] = base_inp_w;
 
-    for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) {
+    for (int kx = 0; kx < kernel_size_x; kx++) {
       // Load weight once, reuse for all 4 width positions
       const int packed_weight = load_weight(kx, ky, c4, KW4, C4);
       const ivec4 weight_4c = unpack_int8x4(packed_weight);
@@ -148,7 +168,7 @@ void main() {
         if (get_outer_packed_dim_block_size(inp_layout) == 1) {
           inp_texel_idx = base_inp_texel_idx + kx * w_texel_step + subtile_w * subtile_w_step;
         } else {
-          // const int w_offset = kx * conv2d_params.dilation.x + subtile_w * conv2d_params.stride.x;
+          // const int w_offset = kx * dilation_x + subtile_w * stride_x;
           // inp_texel_idx = base_inp_texel_idx + div_4(w_offset) * w_stride + mod_4(w_offset);
           // inp_texel_idx = tensor4d_idx_to_texel_idx(inp, inp_tidx, inp_layout);
           const int w4 = div_4(inp_tidx.data[0]);
@@ -162,15 +182,15 @@ void main() {
         acc[subtile_w] += weight_4c * input_4c;
 
         // Advance to next output position's input coordinate
-        inp_tidx.data[0] += conv2d_params.stride.x;
+        inp_tidx.data[0] += stride_x;
       }
 
       // We advanced by 4*stride.x during subtile loop; adjust for net dilation step
-      inp_tidx.data[0] += conv2d_params.dilation.x - 4 * conv2d_params.stride.x;
+      inp_tidx.data[0] += dilation_x - 4 * stride_x;
     }
 
     // Advance height by dilation for next kernel row
-    inp_tidx.data[1] += conv2d_params.dilation.y;
+    inp_tidx.data[1] += dilation_y;
 
     if (get_outer_packed_dim_block_size(inp_layout) == 1) {
       // Advance base index by height step for next kernel row
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.yaml
index 5b671e1e8d5..31c25686b26 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.yaml
@@ -7,8 +7,11 @@
 q8ta_conv2d_dw:
   parameter_names_with_default_values:
     DTYPE: float
+    USE_SPEC_CONST: 0
   generate_variant_forall:
     DTYPE:
       - VALUE: float
   shader_variants:
     - NAME: q8ta_conv2d_dw
+    - NAME: q8ta_conv2d_dw_spec_const
+      USE_SPEC_CONST: 1
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl
index fc063579c45..30e70aa2271 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl
@@ -50,14 +50,22 @@ ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False
 ${layout_declare_ubo(B, "BufferMetadata", "outp")}
 ${layout_declare_ubo(B, "BufferMetadata", "inp")}
 
-layout(push_constant) uniform restrict Block {
-  float input_scale;
-  int input_zp;
-  float output_inv_scale;
-  int output_zp;
-  int K4_per_group;
-  int OC4_per_group;
-};
+$if USE_SPEC_CONST:
+  layout(push_constant) uniform restrict Block {
+    float input_scale;
+    int input_zp;
+    float output_inv_scale;
+    int output_zp;
+  };
+$else:
+  layout(push_constant) uniform restrict Block {
+    float input_scale;
+    int input_zp;
+    float output_inv_scale;
+    int output_zp;
+    int K4_per_group;
+    int OC4_per_group;
+  };
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
@@ -68,6 +76,10 @@ ${layout_declare_spec_const(C, "int", "activation_type", "0")}
 ${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")}
 ${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}
 
+$if USE_SPEC_CONST:
+  ${layout_declare_spec_const(C, "int", "K4_per_group", "1")}
+  ${layout_declare_spec_const(C, "int", "OC4_per_group", "1")}
+
 int compute_outp_buffer_idx(
     const int w_block_idx,
     const int h_idx,
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.yaml
index 46670b8d2aa..e214298629b 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.yaml
@@ -8,6 +8,7 @@ q8ta_conv2d_pw:
   parameter_names_with_default_values:
     DTYPE: float
     USE_INT8_DOT_PRODUCT_EXT: 1
+    USE_SPEC_CONST: 0
  generate_variant_forall:
    DTYPE:
      - VALUE: float
@@ -15,3 +16,8 @@ q8ta_conv2d_pw:
   - NAME: q8ta_conv2d_pw
   - NAME: q8ta_conv2d_pw_fallback
     USE_INT8_DOT_PRODUCT_EXT: 0
+  - NAME: q8ta_conv2d_pw_spec_const
+    USE_SPEC_CONST: 1
+  - NAME: q8ta_conv2d_pw_fallback_spec_const
+    USE_INT8_DOT_PRODUCT_EXT: 0
+    USE_SPEC_CONST: 1
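The pointwise variant takes a different route: `K4_per_group` and `OC4_per_group` move out of the push constant block and into specialization constants. The practical difference, in terms of the raw Vulkan calls the backend ultimately makes, is sketched below; this is illustrative only, not the ExecuTorch abstraction, and the struct and function names are hypothetical.

```cpp
#include <vulkan/vulkan.h>
#include <cstdint>

// Mirrors the USE_SPEC_CONST push constant block declared above.
struct PWPushBlock {
  float input_scale;
  int32_t input_zp;
  float output_inv_scale;
  int32_t output_zp;
};

void record_pw_dispatch(
    VkCommandBuffer cmd,
    VkPipelineLayout layout,
    const PWPushBlock& block,
    uint32_t gx, uint32_t gy, uint32_t gz) {
  // Push constants are re-recorded per dispatch, so they stay cheap to vary.
  vkCmdPushConstants(
      cmd, layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(block), &block);
  // K4_per_group / OC4_per_group are no longer pushed here: under
  // USE_SPEC_CONST they were baked into the pipeline at
  // vkCreateComputePipelines() time via VkSpecializationInfo.
  vkCmdDispatch(cmd, gx, gy, gz);
}
```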
"int", "inp_layout", "CONTIG_LAYOUT_INT")} +$if USE_SPEC_CONST: + // Conv2D parameter specialization constants + ${layout_declare_spec_const(C, "int", "kernel_size_x", "1")} + ${layout_declare_spec_const(C, "int", "stride_x", "1")} + ${layout_declare_spec_const(C, "int", "stride_y", "1")} + ${layout_declare_spec_const(C, "int", "padding_x", "0")} + ${layout_declare_spec_const(C, "int", "padding_y", "0")} + ${layout_declare_spec_const(C, "int", "dilation_x", "1")} + ${layout_declare_spec_const(C, "int", "dilation_y", "1")} + ${layout_declare_spec_const(C, "int", "in_channels_per_group", "1")} + ${layout_declare_spec_const(C, "int", "K_per_group", "1")} +$else: + #include "conv2d_common.glslh" + ${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + #define kernel_size_x conv2d_params.kernel_size.x + #define stride_x conv2d_params.stride.x + #define stride_y conv2d_params.stride.y + #define padding_x conv2d_params.padding.x + #define padding_y conv2d_params.padding.y + #define dilation_x conv2d_params.dilation.x + #define dilation_y conv2d_params.dilation.y + #define in_channels_per_group conv2d_params.in_channels_per_group + #define K_per_group conv2d_params.K_per_group + layout(push_constant) uniform restrict Block { int zp; }; @@ -64,23 +86,23 @@ void main() { const int im2col_h = h_idx; const int im2col_k = mul_4(c4_idx); - const int group_idx = im2col_k / conv2d_params.K_per_group; - const int k_in_group = im2col_k % conv2d_params.K_per_group; + const int group_idx = im2col_k / K_per_group; + const int k_in_group = im2col_k % K_per_group; - const int c_in_group = k_in_group % conv2d_params.in_channels_per_group; - const int krow = k_in_group / conv2d_params.in_channels_per_group; - const int kernel_x = krow % conv2d_params.kernel_size.x; - const int kernel_y = krow / conv2d_params.kernel_size.x; + const int c_in_group = k_in_group % in_channels_per_group; + const int krow = k_in_group / in_channels_per_group; + const int kernel_x = krow % kernel_size_x; + const int kernel_y = krow / kernel_size_x; // Base input position const int input_x_base = - (im2col_w * conv2d_params.stride.x) - conv2d_params.padding.x + - (kernel_x * conv2d_params.dilation.x); + (im2col_w * stride_x) - padding_x + + (kernel_x * dilation_x); const int input_y = - (im2col_h * conv2d_params.stride.y) - conv2d_params.padding.y + - (kernel_y * conv2d_params.dilation.y); + (im2col_h * stride_y) - padding_y + + (kernel_y * dilation_y); const int input_z = - group_idx * conv2d_params.in_channels_per_group + c_in_group; + group_idx * in_channels_per_group + c_in_group; // Input tensor extents const int input_W = input_sizes.x; @@ -98,7 +120,7 @@ void main() { // Each loaded int contains 4 packed int8 channel values. 
ivec4 im2col_block; for (int i = 0; i < 4; i++) { - const int x = input_x_base + i * conv2d_params.stride.x; + const int x = input_x_base + i * stride_x; if (!y_z_in_bounds || x < 0 || x >= input_W) { im2col_block[i] = zp_packed; } else { diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.yaml index 08ce5d59d35..e4174787537 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.yaml @@ -7,5 +7,8 @@ q8ta_im2col: parameter_names_with_default_values: DTYPE: float + USE_SPEC_CONST: 0 shader_variants: - NAME: q8ta_im2col + - NAME: q8ta_im2col_spec_const + USE_SPEC_CONST: 1 diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp index f6e89bef03d..25015d99090 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp @@ -241,7 +241,8 @@ void add_q8ta_conv2d_node( const ValueRef dilation, const ValueRef groups, const uint32_t activation_type, - const ValueRef packed_int8_output) { + const ValueRef packed_int8_output, + const bool spec_const) { (void)packed_int8_input_im2col; // Not used in general shader Conv2DParams conv_params = create_conv2d_params( @@ -291,15 +292,20 @@ void add_q8ta_conv2d_node( const bool use_hw_dot = graph.context()->adapter_ptr()->supports_int8_dot_product(); std::string kernel_name = use_hw_dot ? "q8ta_conv2d" : "q8ta_conv2d_fallback"; + if (spec_const) { + kernel_name += "_spec_const"; + } add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); // Pass metadata for both output and input tensors vkapi::ParamsBindList param_buffers = { graph.buffer_meta_ubo(packed_int8_output), - graph.buffer_meta_ubo(packed_int8_input), - graph.create_params_buffer(conv_params)}; + graph.buffer_meta_ubo(packed_int8_input)}; + if (!spec_const) { + param_buffers.append(graph.create_params_buffer(conv_params)); + } - // Build spec constants: apply_bias, apply_relu + layout constants + // Build spec constants: apply_bias, activation_type, layout constants vkapi::SpecVarList spec_constants = { apply_bias, activation_type, @@ -308,6 +314,23 @@ void add_q8ta_conv2d_node( graph.hashed_layout_of(packed_int8_output), }; + if (spec_const) { + // Conv2D parameter specialization constants + spec_constants.append( + static_cast(conv_params.kernel_size[0])); + spec_constants.append( + static_cast(conv_params.kernel_size[1])); + spec_constants.append(static_cast(conv_params.stride[0])); + spec_constants.append(static_cast(conv_params.stride[1])); + spec_constants.append(static_cast(conv_params.padding[0])); + spec_constants.append(static_cast(conv_params.padding[1])); + spec_constants.append( + static_cast(conv_params.dilation[0])); + spec_constants.append( + static_cast(conv_params.dilation[1])); + spec_constants.append(static_cast(conv_params.groups)); + } + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), @@ -417,6 +440,83 @@ void q8ta_conv2d_general( packed_int8_output); } +void q8ta_conv2d_general_spec_const( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = 
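The flattened-K decomposition in the hunk above is the part most worth double-checking in review. Here is a host-side reference of the same index math, a hypothetical helper (not part of this diff) that keeps the shader's variable names and can be used to unit-check the GLSL:

```cpp
// Splits a flattened im2col K index into (group, channel-within-group,
// kernel x, kernel y), exactly as q8ta_im2col.glsl does per output texel.
struct Im2ColIndex {
  int group_idx;
  int c_in_group;
  int kernel_x;
  int kernel_y;
};

Im2ColIndex decompose_k(
    int im2col_k,
    int K_per_group,
    int in_channels_per_group,
    int kernel_size_x) {
  const int group_idx = im2col_k / K_per_group;
  const int k_in_group = im2col_k % K_per_group;
  const int c_in_group = k_in_group % in_channels_per_group;
  const int krow = k_in_group / in_channels_per_group;
  return {group_idx, c_in_group, krow % kernel_size_x, krow / kernel_size_x};
}

// Example: 3x3 kernel, 8 input channels per group, K_per_group = 8 * 9 = 72.
// im2col_k = 75 lands in group 1, channel 3, at kernel position (0, 0).
```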
diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
index f6e89bef03d..25015d99090 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp
@@ -241,7 +241,8 @@ void add_q8ta_conv2d_node(
     const ValueRef dilation,
     const ValueRef groups,
     const uint32_t activation_type,
-    const ValueRef packed_int8_output) {
+    const ValueRef packed_int8_output,
+    const bool spec_const) {
   (void)packed_int8_input_im2col; // Not used in general shader
 
   Conv2DParams conv_params = create_conv2d_params(
@@ -291,15 +292,20 @@ void add_q8ta_conv2d_node(
   const bool use_hw_dot =
       graph.context()->adapter_ptr()->supports_int8_dot_product();
   std::string kernel_name = use_hw_dot ? "q8ta_conv2d" : "q8ta_conv2d_fallback";
+  if (spec_const) {
+    kernel_name += "_spec_const";
+  }
   add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales));
 
   // Pass metadata for both output and input tensors
   vkapi::ParamsBindList param_buffers = {
       graph.buffer_meta_ubo(packed_int8_output),
-      graph.buffer_meta_ubo(packed_int8_input),
-      graph.create_params_buffer(conv_params)};
+      graph.buffer_meta_ubo(packed_int8_input)};
+  if (!spec_const) {
+    param_buffers.append(graph.create_params_buffer(conv_params));
+  }
 
-  // Build spec constants: apply_bias, apply_relu + layout constants
+  // Build spec constants: apply_bias, activation_type, layout constants
   vkapi::SpecVarList spec_constants = {
       apply_bias,
       activation_type,
@@ -308,6 +314,23 @@ void add_q8ta_conv2d_node(
       graph.hashed_layout_of(packed_int8_output),
   };
 
+  if (spec_const) {
+    // Conv2D parameter specialization constants
+    spec_constants.append(
+        static_cast<int32_t>(conv_params.kernel_size[0]));
+    spec_constants.append(
+        static_cast<int32_t>(conv_params.kernel_size[1]));
+    spec_constants.append(static_cast<int32_t>(conv_params.stride[0]));
+    spec_constants.append(static_cast<int32_t>(conv_params.stride[1]));
+    spec_constants.append(static_cast<int32_t>(conv_params.padding[0]));
+    spec_constants.append(static_cast<int32_t>(conv_params.padding[1]));
+    spec_constants.append(
+        static_cast<int32_t>(conv_params.dilation[0]));
+    spec_constants.append(
+        static_cast<int32_t>(conv_params.dilation[1]));
+    spec_constants.append(static_cast<int32_t>(conv_params.groups));
+  }
+
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
@@ -417,6 +440,83 @@ void q8ta_conv2d_general(
       packed_int8_output);
 }
 
+void q8ta_conv2d_general_spec_const(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int32_t idx = 0;
+  const ValueRef packed_int8_input = args.at(idx++);
+  const ValueRef input_scale = args.at(idx++);
+  const ValueRef input_zp = args.at(idx++);
+  const ValueRef weight_data = args.at(idx++);
+  const ValueRef weight_sums_data = args.at(idx++);
+  const ValueRef weight_scales_data = args.at(idx++);
+  const ValueRef output_scale = args.at(idx++);
+  const ValueRef output_zp = args.at(idx++);
+  const ValueRef bias_data = args.at(idx++);
+  const ValueRef kernel_size = args.at(idx++);
+  const ValueRef stride = args.at(idx++);
+  const ValueRef padding = args.at(idx++);
+  const ValueRef dilation = args.at(idx++);
+  const ValueRef groups = args.at(idx++);
+  const ValueRef activation = args.at(idx++);
+  const ValueRef packed_int8_output = args.at(idx++);
+
+  uint32_t activation_type_val = static_cast<uint32_t>(
+      activation_type_from_string(graph.extract_string(activation)));
+
+  QuantizationConfig weight_quant_config(8, kPerChannel, {});
+
+  ValueRef packed_weight = prepack_quantized_conv2d_weight(
+      graph,
+      weight_quant_config,
+      weight_data,
+      packed_int8_input,
+      packed_int8_output,
+      groups,
+      kernel_size);
+
+  ValueRef packed_weight_sums = prepack_standard(
+      graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked);
+
+  ValueRef packed_weight_scales = prepack_standard(
+      graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked);
+
+  TmpTensor dummy_bias(
+      &graph,
+      {},
+      graph.dtype_of(weight_scales_data),
+      utils::kBuffer,
+      utils::kWidthPacked);
+
+  ValueRef packed_bias = dummy_bias.vref;
+  if (graph.val_is_not_none(bias_data)) {
+    packed_bias =
+        prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked);
+  }
+
+  add_q8ta_conv2d_node(
+      graph,
+      packed_int8_input,
+      packed_int8_input,
+      input_scale,
+      input_zp,
+      packed_weight,
+      packed_weight_sums,
+      packed_weight_scales,
+      output_scale,
+      output_zp,
+      bias_data,
+      packed_bias,
+      kernel_size,
+      stride,
+      padding,
+      dilation,
+      groups,
+      activation_type_val,
+      packed_int8_output,
+      /*spec_const=*/true);
+}
+
 void q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   const ValueRef input = args.at(0);
   const ValueRef groups_ref = args.at(13);
@@ -453,9 +553,44 @@ void q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   }
 }
 
+void q8ta_conv2d_spec_const(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  const ValueRef input = args.at(0);
+  const ValueRef groups_ref = args.at(13);
+  const ValueRef output = args.at(15);
+
+  const int64_t groups = graph.extract_scalar<int64_t>(groups_ref);
+  const int64_t in_channels = graph.size_at<int64_t>(-3, input);
+  const int64_t in_channels_per_group = in_channels / groups;
+
+  const int64_t H_out = graph.size_at<int64_t>(-2, output);
+  const int64_t W_out = graph.size_at<int64_t>(-1, output);
+  const int64_t spatial_out = H_out * W_out;
+
+  const bool im2col_eligible = in_channels_per_group % 4 == 0;
+
+  bool use_im2col = false;
+  if (graph.device_is_mali()) {
+    use_im2col = im2col_eligible;
+  } else {
+    use_im2col = im2col_eligible && groups == 1 &&
+        (in_channels_per_group >= 32 || spatial_out <= 4096);
+  }
+
+  if (use_im2col) {
+    q8ta_conv2d_im2col_spec_const(graph, args);
+  } else {
+    q8ta_conv2d_general_spec_const(graph, args);
+  }
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(et_vk.q8ta_conv2d.default, q8ta_conv2d);
   VK_REGISTER_OP(et_vk.q8ta_conv2d_general.default, q8ta_conv2d_general);
+  VK_REGISTER_OP(et_vk.q8ta_conv2d.spec_const, q8ta_conv2d_spec_const);
+  VK_REGISTER_OP(
+      et_vk.q8ta_conv2d_general.spec_const, q8ta_conv2d_general_spec_const);
 }
 
 } // namespace vkcompute
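The `if (spec_const)` append block introduced here reappears almost verbatim in Q8taConv2dDW.cpp and Q8taConv2dIm2Col.cpp below, each with a slightly different field subset, and the append order must match the shader's declaration order exactly. A possible deduplication, sketch only (`append_conv2d_spec_constants` is a hypothetical helper, not part of this diff):

```cpp
// Appends Conv2DParams fields in the declaration order used by the
// q8ta_conv2d* shaders. Callers opt out of fields their shader does not
// declare (the depthwise shader has no `groups` constant, for example).
void append_conv2d_spec_constants(
    vkapi::SpecVarList& spec_constants,
    const Conv2DParams& p,
    const bool include_groups) {
  spec_constants.append(static_cast<int32_t>(p.kernel_size[0]));
  spec_constants.append(static_cast<int32_t>(p.kernel_size[1]));
  spec_constants.append(static_cast<int32_t>(p.stride[0]));
  spec_constants.append(static_cast<int32_t>(p.stride[1]));
  spec_constants.append(static_cast<int32_t>(p.padding[0]));
  spec_constants.append(static_cast<int32_t>(p.padding[1]));
  spec_constants.append(static_cast<int32_t>(p.dilation[0]));
  spec_constants.append(static_cast<int32_t>(p.dilation[1]));
  if (include_groups) {
    spec_constants.append(static_cast<int32_t>(p.groups));
  }
}
```

The im2col variant below appends a different subset (no `kernel_size_y`, plus `in_channels_per_group` and `K_per_group`), so it would keep those extra appends at its call site.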
diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h
index f463589c50a..801984ff7a6 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h
@@ -67,7 +67,8 @@ void add_q8ta_conv2d_dw_node(
     const ValueRef dilation,
     const ValueRef groups,
     const uint32_t activation_type,
-    const ValueRef packed_int8_output);
+    const ValueRef packed_int8_output,
+    const bool spec_const = false);
 
 void add_conv2d_dw_q8ta_q8csw_q8to_4w4c_node(
     ComputeGraph& graph,
@@ -107,7 +108,8 @@ void add_q8ta_conv2d_node(
     const ValueRef dilation,
     const ValueRef groups,
     const uint32_t activation_type,
-    const ValueRef packed_int8_output);
+    const ValueRef packed_int8_output,
+    const bool spec_const = false);
 
 void add_q8ta_conv2d_pw_node(
     ComputeGraph& graph,
@@ -123,7 +125,8 @@ void add_q8ta_conv2d_pw_node(
     const ValueRef packed_bias,
     const uint32_t activation_type,
     const ValueRef packed_int8_output,
-    const int32_t groups = 1);
+    const int32_t groups = 1,
+    const bool spec_const = false);
 
 std::vector<int64_t> calculate_q8ta_im2col_sizes(
     ComputeGraph* graph,
@@ -142,10 +145,15 @@ void add_q8ta_im2col_node(
     const ValueRef groups,
     const ValueRef packed_int8_output,
     const ValueRef packed_int8_im2col,
-    const int32_t zp);
+    const int32_t zp,
+    const bool spec_const = false);
 
 void q8ta_conv2d_im2col(ComputeGraph& graph, const std::vector<ValueRef>& args);
 
+void q8ta_conv2d_im2col_spec_const(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args);
+
 // Transposed convolution
 
 void q8ta_conv2d_transposed(
diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp
index e690ff435a8..1d996860464 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp
@@ -282,7 +282,8 @@ void add_q8ta_conv2d_dw_node(
     const ValueRef dilation,
     const ValueRef groups,
     const uint32_t activation_type,
-    const ValueRef packed_int8_output) {
+    const ValueRef packed_int8_output,
+    const bool spec_const) {
   Conv2DParams conv_params = create_conv2d_params(
       graph,
       packed_int8_input,
@@ -327,15 +328,20 @@ void add_q8ta_conv2d_dw_node(
   };
 
   std::string kernel_name = "q8ta_conv2d_dw";
+  if (spec_const) {
+    kernel_name += "_spec_const";
+  }
   add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales));
 
   // Pass metadata for both output and input tensors
   vkapi::ParamsBindList param_buffers = {
       graph.buffer_meta_ubo(packed_int8_output),
-      graph.buffer_meta_ubo(packed_int8_input),
-      graph.create_params_buffer(conv_params)};
+      graph.buffer_meta_ubo(packed_int8_input)};
+  if (!spec_const) {
+    param_buffers.append(graph.create_params_buffer(conv_params));
+  }
 
-  // Build spec constants: apply_bias, activation_type + layout constants
+  // Build spec constants: apply_bias, activation_type, layout constants
   vkapi::SpecVarList spec_constants = {
       apply_bias,
       activation_type,
@@ -344,6 +350,22 @@ void add_q8ta_conv2d_dw_node(
       graph.hashed_layout_of(packed_int8_output),
   };
 
+  if (spec_const) {
+    // Conv2D parameter specialization constants
+    spec_constants.append(
+        static_cast<int32_t>(conv_params.kernel_size[0]));
+    spec_constants.append(
+        static_cast<int32_t>(conv_params.kernel_size[1]));
+    spec_constants.append(static_cast<int32_t>(conv_params.stride[0]));
+    spec_constants.append(static_cast<int32_t>(conv_params.stride[1]));
+    spec_constants.append(static_cast<int32_t>(conv_params.padding[0]));
+    spec_constants.append(static_cast<int32_t>(conv_params.padding[1]));
+    spec_constants.append(
+        static_cast<int32_t>(conv_params.dilation[0]));
+    spec_constants.append(
+        static_cast<int32_t>(conv_params.dilation[1]));
+  }
+
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
@@ -442,8 +464,79 @@ void q8ta_conv2d_dw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       packed_int8_output);
 }
 
+void q8ta_conv2d_dw_spec_const(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int32_t idx = 0;
+  const ValueRef packed_int8_input = args.at(idx++);
+  const ValueRef input_scale = args.at(idx++);
+  const ValueRef input_zp = args.at(idx++);
+  const ValueRef weight_data = args.at(idx++);
+  const ValueRef weight_sums_data = args.at(idx++);
+  const ValueRef weight_scales_data = args.at(idx++);
+  const ValueRef output_scale = args.at(idx++);
+  const ValueRef output_zp = args.at(idx++);
+  const ValueRef bias_data = args.at(idx++);
+  const ValueRef kernel_size = args.at(idx++);
+  const ValueRef stride = args.at(idx++);
+  const ValueRef padding = args.at(idx++);
+  const ValueRef dilation = args.at(idx++);
+  const ValueRef groups = args.at(idx++);
+  const ValueRef activation = args.at(idx++);
+  const ValueRef packed_int8_output = args.at(idx++);
+
+  uint32_t activation_type_val = static_cast<uint32_t>(
+      activation_type_from_string(graph.extract_string(activation)));
+
+  QuantizationConfig weight_quant_config(8, kPerChannel, {});
+
+  ValueRef packed_weight = prepack_quantized_conv2d_dw_weight(
+      graph, weight_quant_config, weight_data, kernel_size);
+
+  ValueRef packed_weight_sums = prepack_standard(
+      graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked);
+
+  ValueRef packed_weight_scales = prepack_standard(
+      graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked);
+
+  TmpTensor dummy_bias(
+      &graph,
+      {},
+      graph.dtype_of(weight_scales_data),
+      utils::kBuffer,
+      utils::kWidthPacked);
+
+  ValueRef packed_bias = dummy_bias.vref;
+  if (graph.val_is_not_none(bias_data)) {
+    packed_bias =
+        prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked);
+  }
+
+  add_q8ta_conv2d_dw_node(
+      graph,
+      packed_int8_input,
+      input_scale,
+      input_zp,
+      packed_weight,
+      packed_weight_sums,
+      packed_weight_scales,
+      output_scale,
+      output_zp,
+      bias_data,
+      packed_bias,
+      kernel_size,
+      stride,
+      padding,
+      dilation,
+      groups,
+      activation_type_val,
+      packed_int8_output,
+      /*spec_const=*/true);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(et_vk.q8ta_conv2d_dw.default, q8ta_conv2d_dw);
+  VK_REGISTER_OP(et_vk.q8ta_conv2d_dw.spec_const, q8ta_conv2d_dw_spec_const);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp
index b43fe9eacc6..5e97b05fa2b 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp
@@ -109,7 +109,8 @@ void add_q8ta_im2col_node(
     const ValueRef groups,
     const ValueRef packed_int8_output,
     const ValueRef packed_int8_im2col,
-    const int32_t zp) {
+    const int32_t zp,
+    const bool spec_const) {
   // Validate packed dim info for input and output tensors
   VK_CHECK_COND(q8ta_conv2d_check_packed_dim_info(
       graph.packed_dim_info_of(packed_int8_input)));
@@ -132,23 +133,46 @@ void add_q8ta_im2col_node(
   VK_CHECK_COND(conv_params.in_channels_per_group % 4 == 0);
 
   std::string kernel_name = "q8ta_im2col";
+  if (spec_const) {
+    kernel_name += "_spec_const";
+  }
 
   vkapi::ParamsBindList param_buffers = {
       graph.buffer_meta_ubo(packed_int8_im2col),
-      graph.buffer_meta_ubo(packed_int8_input),
-      graph.create_params_buffer(conv_params)};
+      graph.buffer_meta_ubo(packed_int8_input)};
+  if (!spec_const) {
+    param_buffers.append(graph.create_params_buffer(conv_params));
+  }
 
   std::vector<PushConstantDataInfo> push_constants = {
       PushConstantDataInfo(&zp, sizeof(zp)),
   };
 
-  // Build spec constants: apply_bias + layout constants (for generic shader)
+  // Build spec constants: apply_bias, layout constants
   vkapi::SpecVarList spec_constants = {
       1u,
       graph.hashed_layout_of(packed_int8_im2col),
       graph.hashed_layout_of(packed_int8_input),
   };
 
+  if (spec_const) {
+    // Conv2D parameter specialization constants
+    spec_constants.append(
+        static_cast<int32_t>(conv_params.kernel_size[0]));
+    spec_constants.append(static_cast<int32_t>(conv_params.stride[0]));
+    spec_constants.append(static_cast<int32_t>(conv_params.stride[1]));
+    spec_constants.append(static_cast<int32_t>(conv_params.padding[0]));
+    spec_constants.append(static_cast<int32_t>(conv_params.padding[1]));
+    spec_constants.append(
+        static_cast<int32_t>(conv_params.dilation[0]));
+    spec_constants.append(
+        static_cast<int32_t>(conv_params.dilation[1]));
+    spec_constants.append(
+        static_cast<int32_t>(conv_params.in_channels_per_group));
+    spec_constants.append(
+        static_cast<int32_t>(conv_params.K_per_group));
+  }
+
   // // Add layout specialization constants (only for generic shader)
   // if (!use_4w4c_path) {
   //   spec_constants.append(graph.hashed_layout_of(packed_int8_input));
@@ -275,8 +299,103 @@ void q8ta_conv2d_im2col(
       groups_val);
 }
 
+void q8ta_conv2d_im2col_spec_const(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int32_t idx = 0;
+  const ValueRef packed_int8_input = args.at(idx++);
+  const ValueRef input_scale = args.at(idx++);
+  const ValueRef input_zp = args.at(idx++);
+  const ValueRef weight_data = args.at(idx++);
+  const ValueRef weight_sums_data = args.at(idx++);
+  const ValueRef weight_scales_data = args.at(idx++);
+  const ValueRef output_scale = args.at(idx++);
+  const ValueRef output_zp = args.at(idx++);
+  const ValueRef bias_data = args.at(idx++);
+  const ValueRef kernel_size = args.at(idx++);
+  const ValueRef stride = args.at(idx++);
+  const ValueRef padding = args.at(idx++);
+  const ValueRef dilation = args.at(idx++);
+  const ValueRef groups = args.at(idx++);
+  const ValueRef activation = args.at(idx++);
+  const ValueRef packed_int8_output = args.at(idx++);
+
+  QuantizationConfig weight_quant_config(8, kPerChannel, {});
+
+  ValueRef packed_weight =
+      prepack_quantized_linear_weight(graph, weight_quant_config, weight_data);
+
+  ValueRef packed_weight_sums = prepack_standard(
+      graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked);
+
+  ValueRef packed_weight_scales = prepack_standard(
+      graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked);
+
+  TmpTensor dummy_bias(
+      &graph,
+      {},
+      graph.dtype_of(weight_scales_data),
+      utils::kBuffer,
+      utils::kWidthPacked);
+
+  ValueRef packed_bias = dummy_bias.vref;
+  if (graph.val_is_not_none(bias_data)) {
+    packed_bias =
+        prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked);
+  }
+
+  uint32_t activation_type_val = static_cast<uint32_t>(
+      activation_type_from_string(graph.extract_string(activation)));
+
+  std::vector<int64_t> im2col_sizes = calculate_q8ta_im2col_sizes(
+      &graph, packed_int8_input, packed_int8_output, kernel_size, groups);
+
+  TmpTensor packed_int8_im2col(
+      &graph,
+      im2col_sizes,
+      vkapi::kInt8x4,
+      utils::kBuffer,
+      utils::kPackedInt8_4W4C);
+
+  int32_t zp = graph.extract_scalar<int32_t>(input_zp);
+
+  add_q8ta_im2col_node(
+      graph,
+      packed_int8_input,
+      kernel_size,
+      stride,
+      padding,
+      dilation,
+      groups,
+      packed_int8_output,
+      packed_int8_im2col,
+      zp,
+      /*spec_const=*/true);
+
+  const int32_t groups_val = graph.extract_scalar<int32_t>(groups);
+
+  add_q8ta_conv2d_pw_node(
+      graph,
+      packed_int8_im2col,
+      input_scale,
+      input_zp,
+      packed_weight,
+      packed_weight_sums,
+      packed_weight_scales,
+      output_scale,
+      output_zp,
+      bias_data,
+      packed_bias,
+      activation_type_val,
+      packed_int8_output,
+      groups_val,
+      /*spec_const=*/true);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(et_vk.q8ta_conv2d_im2col.default, q8ta_conv2d_im2col);
+  VK_REGISTER_OP(
+      et_vk.q8ta_conv2d_im2col.spec_const, q8ta_conv2d_im2col_spec_const);
 }
 
 } // namespace vkcompute
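Note that the im2col spec-const path records two specialized nodes back to back: the im2col rearrangement and the pointwise matmul that consumes it. With the registration in place, it is exercised the same way as the `.default` variant; a minimal usage sketch mirroring the argument layout unpacked by `q8ta_conv2d_im2col_spec_const` above (the ValueRefs are assumed to be set up as in the test harness below):

```cpp
// Hypothetical call site; all ValueRefs must already exist in `graph`.
std::vector<ValueRef> conv_args = {
    packed_int8_input, input_scale, input_zp,
    weight_data, weight_sums_data, weight_scales_data,
    output_scale, output_zp, bias_data,
    kernel_size, stride, padding, dilation, groups,
    activation, packed_int8_output};
VK_GET_OP_FN("et_vk.q8ta_conv2d_im2col.spec_const")(graph, conv_args);
```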
diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp
index e27e0699dac..f7b67387ac4 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp
@@ -201,7 +201,8 @@ void add_q8ta_conv2d_pw_node(
     const ValueRef packed_bias,
     const uint32_t activation_type,
     const ValueRef packed_int8_output,
-    const int32_t groups) {
+    const int32_t groups,
+    const bool spec_const) {
   VK_CHECK_COND(q8ta_conv2d_check_4w4c_packed_dim_info(
       graph.packed_dim_info_of(packed_int8_input)));
   VK_CHECK_COND(q8ta_conv2d_check_packed_dim_info(
@@ -232,26 +233,37 @@ void add_q8ta_conv2d_pw_node(
       PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)),
       PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)),
       PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)),
-      PushConstantDataInfo(&K4_per_group, sizeof(K4_per_group)),
-      PushConstantDataInfo(&OC4_per_group, sizeof(OC4_per_group)),
   };
+  if (!spec_const) {
+    push_constants.push_back(
+        PushConstantDataInfo(&K4_per_group, sizeof(K4_per_group)));
+    push_constants.push_back(
+        PushConstantDataInfo(&OC4_per_group, sizeof(OC4_per_group)));
+  }
 
   const bool use_hw_dot =
       graph.context()->adapter_ptr()->supports_int8_dot_product();
   std::string kernel_name =
       use_hw_dot ? "q8ta_conv2d_pw" : "q8ta_conv2d_pw_fallback";
+  if (spec_const) {
+    kernel_name += "_spec_const";
+  }
   add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales));
 
   vkapi::ParamsBindList param_buffers = {
       graph.buffer_meta_ubo(packed_int8_output),
       graph.buffer_meta_ubo(packed_int8_input)};
 
-  vkapi::SpecVarList spec_constants = {
+  vkapi::SpecVarList spec_constants_list = {
       apply_bias,
       activation_type,
       graph.hashed_layout_of(packed_int8_output),
       graph.hashed_layout_of(packed_int8_input),
   };
+  if (spec_const) {
+    spec_constants_list.append(static_cast<int32_t>(K4_per_group));
+    spec_constants_list.append(static_cast<int32_t>(OC4_per_group));
+  }
 
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
@@ -267,7 +279,7 @@ void add_q8ta_conv2d_pw_node(
           vkapi::kRead}},
       param_buffers,
       push_constants,
-      spec_constants,
+      spec_constants_list,
       {}));
 }
 
@@ -347,8 +359,81 @@ void q8ta_conv2d_pw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       packed_int8_output);
 }
 
+void q8ta_conv2d_pw_spec_const(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int32_t idx = 0;
+  const ValueRef packed_int8_input = args.at(idx++);
+  const ValueRef input_scale = args.at(idx++);
+  const ValueRef input_zp = args.at(idx++);
+  const ValueRef weight_data = args.at(idx++);
+  const ValueRef weight_sums_data = args.at(idx++);
+  const ValueRef weight_scales_data = args.at(idx++);
+  const ValueRef output_scale = args.at(idx++);
+  const ValueRef output_zp = args.at(idx++);
+  const ValueRef bias_data = args.at(idx++);
+  // Accept but ignore conv params - pointwise has fixed kernel=1x1, stride=1,
+  // padding=0, dilation=1, groups=1
+  (void)args.at(idx++); // kernel_size
+  (void)args.at(idx++); // stride
+  (void)args.at(idx++); // padding
+  (void)args.at(idx++); // dilation
+  (void)args.at(idx++); // groups
+  const ValueRef activation_ref = args.at(idx++);
+  const ValueRef packed_int8_output = args.at(idx++);
+
+  uint32_t activation_type_val = static_cast<uint32_t>(
+      activation_type_from_string(graph.extract_string(activation_ref)));
+
+  QuantizationConfig weight_quant_config(8, kPerChannel, {});
+
+  ValueRef packed_weight = prepack_quantized_conv2d_pw_weight(
+      graph,
+      weight_quant_config,
+      weight_data,
+      packed_int8_input,
+      packed_int8_output);
+
+  ValueRef packed_weight_sums = prepack_standard(
+      graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked);
+
+  ValueRef packed_weight_scales = prepack_standard(
+      graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked);
+
+  TmpTensor dummy_bias(
+      &graph,
+      {},
+      graph.dtype_of(weight_scales_data),
+      utils::kBuffer,
+      utils::kWidthPacked);
+
+  ValueRef packed_bias = dummy_bias.vref;
+  if (graph.val_is_not_none(bias_data)) {
+    packed_bias =
+        prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked);
+  }
+
+  add_q8ta_conv2d_pw_node(
+      graph,
+      packed_int8_input,
+      input_scale,
+      input_zp,
+      packed_weight,
+      packed_weight_sums,
+      packed_weight_scales,
+      output_scale,
+      output_zp,
+      bias_data,
+      packed_bias,
+      activation_type_val,
+      packed_int8_output,
+      /*groups=*/1,
+      /*spec_const=*/true);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(et_vk.q8ta_conv2d_pw.default, q8ta_conv2d_pw);
+  VK_REGISTER_OP(et_vk.q8ta_conv2d_pw.spec_const, q8ta_conv2d_pw_spec_const);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp
index 679ac33d11b..efb562f5862 100644
--- a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp
+++ b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp
@@ -79,6 +79,25 @@ void test_q8ta_conv2d_dw(
         groups,
         packed_int8_output};
     VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args);
+  } else if (impl_selector == "spec_const") {
+    std::vector<ValueRef> conv_args = {
+        packed_int8_input,
+        input_scale,
+        input_zp,
+        weight_data,
+        weight_sums_data,
+        weight_scales_data,
+        output_scale,
+        output_zp,
+        bias_data,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        groups,
+        activation,
+        packed_int8_output};
+    VK_GET_OP_FN("et_vk.q8ta_conv2d_dw.spec_const")(graph, conv_args);
   } else {
     std::vector<ValueRef> conv_args = {
         packed_int8_input,
@@ -190,6 +209,12 @@ void test_q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
     VK_GET_OP_FN("et_vk.q8ta_conv2d_im2col.default")(graph, conv_args);
   } else if (impl_selector == "general") {
     VK_GET_OP_FN("et_vk.q8ta_conv2d_general.default")(graph, conv_args);
+  } else if (impl_selector == "spec_const") {
+    VK_GET_OP_FN("et_vk.q8ta_conv2d.spec_const")(graph, conv_args);
+  } else if (impl_selector == "general_spec_const") {
+    VK_GET_OP_FN("et_vk.q8ta_conv2d_general.spec_const")(graph, conv_args);
+  } else if (impl_selector == "im2col_spec_const") {
+    VK_GET_OP_FN("et_vk.q8ta_conv2d_im2col.spec_const")(graph, conv_args);
   } else {
     VK_GET_OP_FN("et_vk.q8ta_conv2d.default")(graph, conv_args);
   }
@@ -269,6 +294,25 @@ void test_q8ta_conv2d_pw(
         groups,
         packed_int8_output};
     VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args);
+  } else if (impl_selector == "spec_const") {
+    std::vector<ValueRef> conv_args = {
+        packed_int8_input,
+        input_scale,
+        input_zp,
+        weight_data,
+        weight_sums_data,
+        weight_scales_data,
+        output_scale,
+        output_zp,
+        bias_data,
+        kernel_size,
+        stride,
+        padding,
+        dilation,
+        groups,
+        activation,
+        packed_int8_output};
+    VK_GET_OP_FN("et_vk.q8ta_conv2d_pw.spec_const")(graph, conv_args);
   } else {
     std::vector<ValueRef> conv_args = {
         packed_int8_input,
diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp
index 9f0273a5b83..82e13a8fc00 100644
--- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp
+++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp
@@ -236,6 +236,9 @@ std::vector<TestCase> generate_quantized_conv2d_easy_cases() {
         make_test_case_name(config, false, fp_storage_type, utils::kBuffer);
     test_cases.push_back(create_test_case_from_config(
         config, vkapi::kFloat, fp_storage_type, int8_memory_layout));
+    test_cases.push_back(create_test_case_from_config(
+        config, vkapi::kFloat, fp_storage_type, int8_memory_layout,
+        /*impl_selector=*/"spec_const"));
 
     // Test im2col implementation when input channels per group is a
     // multiple of 4
@@ -440,6 +443,12 @@ static std::vector<TestCase> generate_quantized_conv2d_test_cases() {
           fp_storage_type,
           int8_memory_layout,
           /*impl_selector=*/"general"));
+      test_cases.push_back(create_test_case_from_config(
+          config,
+          vkapi::kFloat,
+          fp_storage_type,
+          int8_memory_layout,
+          /*impl_selector=*/"general_spec_const"));
 
       // Test im2col implementation when input channels per group is a
       // multiple of 4
@@ -452,6 +461,12 @@ static std::vector<TestCase> generate_quantized_conv2d_test_cases() {
             fp_storage_type,
             int8_memory_layout,
             /*impl_selector=*/"im2col"));
+        test_cases.push_back(create_test_case_from_config(
+            config,
+            vkapi::kFloat,
+            fp_storage_type,
+            int8_memory_layout,
+            /*impl_selector=*/"im2col_spec_const"));
       }
 
       // For 4W4C layout, also test the legacy implementation
diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp
index 0734e444d57..cf521653229 100644
--- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp
+++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp
@@ -246,6 +246,9 @@ std::vector<TestCase> generate_quantized_conv2d_dw_easy_cases() {
         make_test_case_name(config, false, fp_storage_type, utils::kBuffer);
     test_cases.push_back(create_test_case_from_config(
         config, vkapi::kFloat, fp_storage_type, int8_memory_layout));
+    test_cases.push_back(create_test_case_from_config(
+        config, vkapi::kFloat, fp_storage_type, int8_memory_layout,
+        /*impl_selector=*/"spec_const"));
 
     // For 4W4C layout, also test the legacy implementation
     if (int8_memory_layout == utils::kPackedInt8_4W4C) {
@@ -376,6 +379,9 @@ std::vector<TestCase> generate_quantized_conv2d_dw_test_cases() {
         config, is_performance, fp_storage_type, utils::kBuffer);
     test_cases.push_back(create_test_case_from_config(
         config, vkapi::kFloat, fp_storage_type, int8_memory_layout));
+    test_cases.push_back(create_test_case_from_config(
+        config, vkapi::kFloat, fp_storage_type, int8_memory_layout,
+        /*impl_selector=*/"spec_const"));
 
     // For 4W4C layout, also test the legacy implementation
     if (int8_memory_layout == utils::kPackedInt8_4W4C) {
diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp
index 83b9f92fb3a..d57c6eb68c8 100644
--- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp
+++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp
@@ -334,6 +334,12 @@ static std::vector<TestCase> generate_quantized_conv2d_pw_test_cases() {
         config, is_performance, fp_storage_type, utils::kBuffer);
     test_cases.push_back(create_test_case_from_config(
         config, vkapi::kFloat, fp_storage_type, int8_memory_layout));
+    test_cases.push_back(create_test_case_from_config(
+        config,
+        vkapi::kFloat,
+        fp_storage_type,
+        int8_memory_layout,
+        /*impl_selector=*/"spec_const"));
 
     // For 4W4C layout, also test the legacy implementation
     if (int8_memory_layout == utils::kPackedInt8_4W4C) {