Add CPU DynamicQuantMatMulFp8 contrib op with MLAS FP8 fallback #28416
Draft · melkap01-Arm wants to merge 2 commits into microsoft:main
Signed-off-by: melkap01 <melike.kaptan@arm.com>
Description
This PR adds a CPU contrib implementation for com.microsoft::DynamicQuantMatMulFp8. The path supports dynamic quantization of float/float16/bfloat16 activations to FP8; runtime FP8 B weights or constant B weights (pre-quantized, or quantized once during PrePack); block-wise scales; configurable block sizes; and float/float16/bfloat16 outputs.
Main Changes
Adds DynamicQuantMatMulFp8 schema under the Microsoft contrib opset.
Registers the CPU contrib kernel when FP8 types are enabled.
Adds dynamic_quant_matmul_fp8.{h,cc} CPU op-kernel implementation.
Adds MlasFp8GemmBatch and its scalar qgemm_fp8.cpp fallback implementation, which performs the FP8 GEMM compute path used by the DynamicQuantMatMulFp8 CPU kernel.
Wires the MLAS FP8 source into the MLAS build.
Adds provider tests for the FP8 op-kernel path.
Operator Contract
A supports float, float16, and bfloat16.
Runtime B supports FP8 only.
Constant initializer B supports float, float16, bfloat16, or FP8.
Non-FP8 constant B is quantized once during PrePack.
Dynamic non-FP8 B is intentionally rejected.
Output Y supports float, float16, and bfloat16.
Scale tensors support float, float16, and bfloat16.
FP8 formats supported:
FLOAT8E4M3FN
FLOAT8E4M3FNUZ
FLOAT8E5M2
FLOAT8E5M2FNUZ
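For concreteness, here is a minimal sketch of what decoding one of these formats involves, using FLOAT8E4M3FN (1 sign, 4 exponent, and 3 mantissa bits, bias 7, no infinities, a single all-ones NaN encoding); the helper name is illustrative and is not the PR's code:

```cpp
#include <cmath>
#include <cstdint>

// Decode one FLOAT8E4M3FN byte to float. Illustrative only; the PR's MLAS
// fallback has its own mode-dependent decoding for all four formats.
float DecodeFp8E4M3FN(uint8_t v) {
  const float sign = (v & 0x80) ? -1.0f : 1.0f;
  const int exponent = (v >> 3) & 0x0F;
  const int mantissa = v & 0x07;
  if (exponent == 0x0F && mantissa == 0x07) {
    return std::nanf("");  // the single NaN encoding; E4M3FN has no infinities
  }
  if (exponent == 0) {
    return sign * std::ldexp(static_cast<float>(mantissa), -9);  // subnormal: (mantissa/8) * 2^-6
  }
  return sign * std::ldexp(1.0f + mantissa / 8.0f, exponent - 7);  // normal, bias 7
}
```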
Quantization Semantics
The implementation enforces symmetric quantization.
All A/B/Y zero-point inputs, when provided, must encode 0.0.
Non-zero zero points are rejected.
Scale values are validated as finite and positive before use.
Y_scale and Y_zero_point are optional schema inputs.
Y_scale, when provided, must be scalar and is applied to the final accumulation.
Y_zero_point, when provided, must be scalar and zero-valued.
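A rough sketch of those checks, assuming already-decoded float values (function names are hypothetical, not the PR's):

```cpp
#include <cmath>

// Symmetric quantization: any provided A/B/Y zero point must decode to 0.0.
bool IsValidZeroPoint(float decoded_zero_point) {
  return decoded_zero_point == 0.0f;
}

// Scales are validated as finite and strictly positive before use.
bool IsValidScale(float scale) {
  return std::isfinite(scale) && scale > 0.0f;
}
```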
Block Layout
Adds block_size_m, block_size_k, and block_size_n attributes, all defaulting to 128.
A scale / zero-point layout is validated against ceil(M / block_size_m) and K / block_size_k.
B scale / zero-point layout is validated against K / block_size_k and N / block_size_n.
Rank-2 A scale / zero-point tensors are allowed as shared tensors across GEMMs.
Batched A scale / zero-point tensors must match the output GEMM batch layout currently supported by the kernel.
Shape inference was tightened to match runtime behavior and avoid accepting unsupported broadcasted A-scale layouts.
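A minimal sketch of the block-count arithmetic these checks imply; rounding up on all three axes is an assumption here, since the text writes ceil() explicitly only for the M axis:

```cpp
#include <cstdint>

// ceil(a / b) for positive a, b.
int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

struct BlockCounts {
  int64_t m_blocks, k_blocks, n_blocks;
};

// A scales / zero points: [m_blocks, k_blocks] (rank-2, shared across GEMMs);
// B scales / zero points: [k_blocks, n_blocks].
BlockCounts ComputeBlockCounts(int64_t M, int64_t K, int64_t N,
                               int64_t bm, int64_t bk, int64_t bn) {
  return {CeilDiv(M, bm), CeilDiv(K, bk), CeilDiv(N, bn)};
}
```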
Kernel Behavior
Runtime FP8 B is consumed directly.
Constant non-FP8 B is converted to FP8 in PrePack.
Prepacked B metadata restores B shape, FP8 type, and packed buffer size for shared prepack reuse.
FP8 type consistency is validated across A/B and B/B-zero-point.
Runtime B rank is restricted to 2D for the non-prepacked path.
K == 0 produces zero-filled output instead of returning uninitialized data.
M == 0 and N == 0 empty outputs return cleanly.
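A hedged sketch of the constant-B PrePack idea: divide each block by its scale and cast to FP8. EncodeFp8E4M3FN is a hypothetical encoder (the inverse of the decode sketch above); the PR's packing layout, rounding mode, and metadata bookkeeping are not shown:

```cpp
#include <cstddef>
#include <cstdint>

uint8_t EncodeFp8E4M3FN(float x);  // hypothetical encoder, declaration only

// Quantize one block of a constant float B during PrePack, given that
// block's scale. The real kernel also records B's shape, FP8 type, and
// packed buffer size so a shared prepacked B can be restored later.
void PrePackBlock(const float* b_block, uint8_t* packed, size_t count,
                  float block_scale) {
  for (size_t i = 0; i < count; ++i) {
    packed[i] = EncodeFp8E4M3FN(b_block[i] / block_scale);
  }
}
```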
MLAS FP8 Fallback
Adds MlasFp8GemmBatch / MlasFp8Gemm API.
Implements FP8 decode, scale application, float accumulation, optional output scaling, and output zero-point handling.
Supports all four FP8 modes listed above.
Parallelizes fallback work over BatchN * M.
Adds defensive validation before threaded execution:
valid FP8 mode
non-zero block sizes
required pointers only when actually dereferenced
leading dimensions only when used
strided offset overflow checks
block scale / zero-point offset overflow checks
public block-count validation against shape-derived block counts
This is a functional scalar fallback, not a hardware-optimized FP8 GEMM backend.
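A minimal sketch of that compute path for a single output element; DecodeFp8 stands in for the mode-dependent decoder, the scale indexing is simplified to one shared K-block axis, and none of these names are the actual MlasFp8Gemm internals:

```cpp
#include <cstddef>
#include <cstdint>

float DecodeFp8(uint8_t v);  // hypothetical, dispatched on the FP8 mode

// One output element: decode FP8 operands, apply block-wise scales,
// accumulate in float, then apply the optional Y_scale (shown as a
// multiply; the actual direction is an assumption). Zero points are
// omitted because only zero-valued ones are accepted.
float Fp8DotProduct(const uint8_t* a_row, const uint8_t* b_col, size_t K,
                    size_t block_size_k, const float* a_scales,
                    const float* b_scales, float y_scale) {
  float acc = 0.0f;
  for (size_t k = 0; k < K; ++k) {
    const size_t kb = k / block_size_k;  // K-block this element falls in
    acc += (DecodeFp8(a_row[k]) * a_scales[kb]) *
           (DecodeFp8(b_col[k]) * b_scales[kb]);
  }
  return acc * y_scale;  // pass y_scale = 1.0f when Y_scale is absent
}
```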
Tests
Provider tests cover:
Constant non-FP8 B prepack path.
Runtime FP8 B path.
Omitted optional output quantization inputs.
Optional Y_scale.
Float16 and bfloat16 outputs.
Bfloat16 scale tensors.
Non-zero zero-point rejection for A/B/Y (symmetric quantization enforced).
FP8 B / B-zero-point type mismatch rejection.
Non-default block sizes and partial M blocks.
Shared prepacked B metadata restore.
Shared prepack semantic correctness with different B scales.
Rejection of unsupported dynamic non-FP8 B.
Rejection of unsupported batched A-scale layouts.
Malformed A/B scale shape validation before scale reads.
M == 0, N == 0, and K == 0 edge cases.
Invalid Y_scale shape, value, and type on the K == 0 path.
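As an illustration of how such a provider test is typically written with onnxruntime's OpTester helper; the op's input names, scale shapes, and opset version below are guesses from the contract above, not the PR's actual test code:

```cpp
#include <vector>

#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

// Hypothetical happy path: float A, constant float B (quantized in PrePack),
// single-block scales, float output. All input names are assumptions.
TEST(DynamicQuantMatMulFp8, ConstantFloatBPrePackSketch) {
  OpTester test("DynamicQuantMatMulFp8", 1, kMSDomain);
  test.AddAttribute<int64_t>("block_size_m", 128);
  test.AddAttribute<int64_t>("block_size_k", 128);
  test.AddAttribute<int64_t>("block_size_n", 128);
  test.AddInput<float>("A", {2, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
  // A constant initializer B takes the one-time PrePack quantization path.
  test.AddInput<float>("B", {4, 3}, std::vector<float>(12, 0.5f),
                       /*is_initializer=*/true);
  test.AddInput<float>("A_scale", {1, 1}, {1.0f});  // ceil(2/128) x ceil(4/128)
  test.AddInput<float>("B_scale", {1, 1}, {1.0f});  // ceil(4/128) x ceil(3/128)
  // 0.5 and the small-integer activations are exact in FP8, so Y matches the
  // float MatMul: row sums {10, 26} scaled by 0.5.
  test.AddOutput<float>("Y", {2, 3}, {5.f, 5.f, 5.f, 13.f, 13.f, 13.f});
  test.Run();
}

}  // namespace test
}  // namespace onnxruntime
```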
Known Limitations
No dynamic non-FP8 B support by design.
No packed-B optimized FP8 backend is exposed yet.
No KleidiAI FP8 dispatch is included in this path.
MLAS FP8 GEMM is currently correctness-oriented scalar fallback code, not a production performance kernel.
Full MatMul broadcast semantics for batched A scale tensors are not implemented; schema/runtime validation is tightened to the currently supported layout.
Verification
Built onnxruntime_provider_test and onnxruntime_mlas_test; result: passed.
Ran on a Qwen3 model (converted to .onnx); result: all DynamicQuantMatMulFp8 tests passed.
Motivation and Context