
Add CPU DynamicQuantMatMulFp8 contrib op with MLAS FP8 fallback #28416

Draft
melkap01-Arm wants to merge 2 commits into microsoft:main from melkap01-Arm:fp8_DynamicQuantMatMul_Support

Conversation

@melkap01-Arm
Contributor

Description

Add CPU DynamicQuantMatMulFp8 contrib op with MLAS FP8 fallback

This PR adds a CPU contrib implementation of com.microsoft::DynamicQuantMatMulFp8. The path supports dynamic quantization of float/float16/bfloat16 activations to FP8, B weights that are either runtime FP8 or constant initializers quantized during PrePack, block-wise scales with configurable block sizes, and float/float16/bfloat16 outputs.

Main Changes

Adds DynamicQuantMatMulFp8 schema under the Microsoft contrib opset.
Registers the CPU contrib kernel when FP8 types are enabled.
Adds dynamic_quant_matmul_fp8.{h,cc} CPU opkernel implementation.
Adds MlasFp8GemmBatch and its scalar qgemm_fp8.cpp fallback implementation, which performs the FP8 GEMM compute path used by the DynamicQuantMatMulFp8 CPU kernel.
Wires the MLAS FP8 source into the MLAS build.
Adds provider tests for the FP8 op-kernel path.

Operator Contract

A supports float, float16, and bfloat16.
Runtime B supports FP8 only.
Constant initializer B supports float, float16, bfloat16, or FP8.
Non-FP8 constant B is quantized once during PrePack.
Dynamic non-FP8 B is intentionally rejected.
Output Y supports float, float16, and bfloat16.
Scale tensors support float, float16, and bfloat16.
FP8 formats supported:

FLOAT8E4M3FN
FLOAT8E4M3FNUZ
FLOAT8E5M2
FLOAT8E5M2FNUZ
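For reference, a minimal scalar decode of the first of these formats, FLOAT8E4M3FN (1 sign bit, 4 exponent bits, 3 mantissa bits, exponent bias 7, no infinities, only S.1111.111 as NaN). This is an illustrative sketch, not the decode path used by the kernel:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Illustrative scalar decode for FLOAT8E4M3FN (sign:1, exponent:4, mantissa:3,
// exponent bias 7). E4M3FN has no infinities; S.1111.111 encodes NaN, so the
// largest finite magnitude is (1 + 6/8) * 2^8 = 448.
float DecodeFp8E4M3fn(uint8_t v) {
  const bool negative = (v & 0x80) != 0;
  const int exponent = (v >> 3) & 0xF;
  const int mantissa = v & 0x7;
  if (exponent == 0xF && mantissa == 0x7) {
    return std::numeric_limits<float>::quiet_NaN();  // no Inf in this format
  }
  float magnitude;
  if (exponent == 0) {
    // Subnormal: mantissa * 2^(1 - bias - 3) = mantissa * 2^-9
    magnitude = std::ldexp(static_cast<float>(mantissa), -9);
  } else {
    // Normal: (1 + mantissa/8) * 2^(exponent - bias)
    magnitude = std::ldexp(1.0f + mantissa / 8.0f, exponent - 7);
  }
  return negative ? -magnitude : magnitude;
}
```

The FNUZ variants differ in that 0x80 encodes NaN rather than -0.0, and E5M2 trades a mantissa bit for an extra exponent bit.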

Quantization Semantics

The implementation enforces symmetric quantization.
All A/B/Y zero-point inputs, when provided, must encode 0.0.
Non-zero zero points are rejected.
Scale values are validated as finite and positive before use.
Y_scale and Y_zero_point are optional schema inputs.
Y_scale, when provided, must be scalar and is applied to the final accumulation.
Y_zero_point, when provided, must be scalar and zero-valued.
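A minimal sketch of these checks, with hypothetical helper names (the kernel's actual helpers may differ). One detail is an assumption here: for E4M3FN/E5M2 both 0x00 (+0.0) and 0x80 (-0.0) encode 0.0, whereas in the FNUZ formats 0x80 is NaN, so only 0x00 can pass:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Illustrative validation mirroring the symmetric-quantization contract.

// Scale values must be finite and strictly positive before use.
bool IsValidScale(float scale) {
  return std::isfinite(scale) && scale > 0.0f;
}

// A zero point, when provided, must encode 0.0. For FLOAT8E4M3FN/FLOAT8E5M2
// that is 0x00 (+0.0) or 0x80 (-0.0); the FNUZ formats have no -0.0 (0x80 is
// NaN there), so only 0x00 qualifies.
bool ZeroPointEncodesZero(uint8_t zp, bool is_fnuz) {
  return zp == 0x00 || (!is_fnuz && zp == 0x80);
}
```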

Block Layout

Adds block_size_m, block_size_k, and block_size_n attributes, all defaulting to 128.
A scale / zero-point layout is validated against ceil(M / block_size_m) and K / block_size_k.
B scale / zero-point layout is validated against K / block_size_k and N / block_size_n.
Rank-2 A scale / zero-point tensors are allowed as shared tensors across GEMMs.
Batched A scale / zero-point tensors must match the output GEMM batch layout currently supported by the kernel.
Shape inference was tightened to match runtime behavior and avoid accepting unsupported broadcasted A-scale layouts.
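The block-count arithmetic above can be sketched as follows, assuming (per the validation rules) that K and N divide evenly by their block sizes while M may have a partial trailing block:

```cpp
#include <cassert>
#include <cstdint>

// Illustrative block-count computation for the block-wise scale layout:
// A scales span ceil(M / block_size_m) x (K / block_size_k) blocks,
// B scales span (K / block_size_k) x (N / block_size_n) blocks.
struct BlockCounts {
  int64_t a_m_blocks;  // ceil(M / block_size_m); partial M blocks allowed
  int64_t k_blocks;    // K / block_size_k; shared by the A and B layouts
  int64_t n_blocks;    // N / block_size_n
};

BlockCounts ComputeBlockCounts(int64_t M, int64_t K, int64_t N,
                               int64_t bm, int64_t bk, int64_t bn) {
  return {(M + bm - 1) / bm, K / bk, N / bn};
}
```

For example, with the default 128 block sizes, M = 300, K = 256, N = 512 gives a 3 x 2 A-scale layout and a 2 x 4 B-scale layout.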

Kernel Behavior

Runtime FP8 B is consumed directly.
Constant non-FP8 B is converted to FP8 in PrePack.
Prepacked B metadata restores B shape, FP8 type, and packed buffer size for shared prepack reuse.
FP8 type consistency is validated across A/B and B/B-zero-point.
Runtime B rank is restricted to 2D for the non-prepacked path.
K == 0 produces zero-filled output instead of returning uninitialized data.
M == 0 and N == 0 empty outputs return cleanly.
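The degenerate-shape handling can be summarized in a short sketch (illustrative, not the kernel code): an empty reduction dimension sums to zero, so K == 0 must produce a zero-filled output rather than an untouched buffer:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative handling of degenerate GEMM shapes. With M == 0 or N == 0 the
// output tensor is empty and there is nothing to write; with K == 0 the
// M x N output must be zero-filled, because an empty sum is 0, not whatever
// happened to be in the uninitialized buffer.
void HandleDegenerateShapes(float* Y, size_t M, size_t N, size_t K) {
  if (M == 0 || N == 0) return;  // empty output, return cleanly
  if (K == 0) std::fill_n(Y, M * N, 0.0f);
}
```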

MLAS FP8 Fallback

Adds MlasFp8GemmBatch / MlasFp8Gemm API.
Implements FP8 decode, scale application, float accumulation, optional output scaling, and output zero-point handling.
Supports all four FP8 modes listed above.
Parallelizes fallback work over BatchN * M.
Adds defensive validation before threaded execution:

valid FP8 mode
non-zero block sizes
required pointers only when actually dereferenced
leading dimensions only when used
strided offset overflow checks
block scale / zero-point offset overflow checks
public block-count validation against shape-derived block counts

This is a functional scalar fallback, not a hardware-optimized FP8 GEMM backend.
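To make the compute path concrete, here is a simplified single-GEMM sketch of what such a scalar fallback does, under stated assumptions: block sizes divide K and N evenly, a caller-supplied 256-entry lookup table stands in for the FP8 decode, zero points are zero (per the symmetric contract), and the parallelization and defensive validation are omitted. This is illustrative, not the actual MlasFp8Gemm code:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Simplified scalar FP8 GEMM sketch: decode A and B through a 256-entry
// lookup table, apply per-block scales, accumulate in float, then apply an
// optional scalar output scale. Assumes bk divides K and bn divides N.
void ScalarFp8GemmSketch(const uint8_t* A,       // [M x K], FP8 codes
                         const uint8_t* B,       // [K x N], FP8 codes
                         const float* a_scales,  // [ceil(M/bm) x K/bk]
                         const float* b_scales,  // [K/bk x N/bn]
                         float* Y,               // [M x N], float output
                         size_t M, size_t K, size_t N,
                         size_t bm, size_t bk, size_t bn,
                         const float decode[256], float y_scale) {
  const size_t k_blocks = K / bk;
  const size_t n_blocks = N / bn;
  for (size_t m = 0; m < M; ++m) {
    for (size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (size_t kb = 0; kb < k_blocks; ++kb) {
        const float sa = a_scales[(m / bm) * k_blocks + kb];
        const float sb = b_scales[kb * n_blocks + (n / bn)];
        float partial = 0.0f;  // accumulate one K-block, then scale once
        for (size_t k = kb * bk; k < (kb + 1) * bk; ++k) {
          partial += decode[A[m * K + k]] * decode[B[k * N + n]];
        }
        acc += partial * sa * sb;
      }
      Y[m * N + n] = acc * y_scale;  // scalar Y_scale; 1.0f when omitted
    }
  }
}
```

The real fallback additionally parallelizes over BatchN * M and performs the validation listed above before any thread touches the buffers.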

Tests

Provider tests cover:

Constant non-FP8 B prepack path.
Runtime FP8 B path.
Omitted optional output quantization inputs.
Optional Y_scale.
Float16 and bfloat16 outputs.
Bfloat16 scale tensors.
Symmetric zero-point rejection for A/B/Y.
FP8 B / B-zero-point type mismatch rejection.
Non-default block sizes and partial M blocks.
Shared prepacked B metadata restore.
Shared prepack semantic correctness with different B scales.
Rejection of unsupported dynamic non-FP8 B.
Batched A-scale layout rejection.
Malformed A/B scale shape validation before scale reads.
M == 0, N == 0, and K == 0 edge cases.
Invalid Y_scale shape, value, and type on the K == 0 path.

Known Limitations

No dynamic non-FP8 B support by design.
No packed-B optimized FP8 backend is exposed yet.
No KleidiAI FP8 dispatch is included in this path.
MLAS FP8 GEMM is currently correctness-oriented scalar fallback code, not a production performance kernel.
Full MatMul broadcast semantics for batched A scale tensors are not implemented; schema/runtime validation is tightened to the currently supported layout.

Verification

Built onnxruntime_provider_test and onnxruntime_mlas_test; result: passed.
Ran a Qwen3 model (converted to .onnx); result: all DynamicQuantMatMulFp8 tests passed.

Motivation and Context

Adds DynamicQuantMatMulFp8 schema under the Microsoft contrib opset.
Registers the CPU contrib kernel when FP8 types are enabled.
Adds dynamic_quant_matmul_fp8.{h,cc} CPU kernel implementation.
Adds MLAS FP8 GEMM API surface and scalar fallback implementation in qgemm_fp8.cpp.
Wires the MLAS FP8 source into the MLAS build.
Adds provider tests for the FP8 op-kernel path.

Signed-off-by: melkap01 <melike.kaptan@arm.com>