[OMNIML-3349] Add FP8 MHA quantization support for HuggingFace ViT#1289
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the review settings.
📝 Walkthrough: Adds FP8-focused ONNX export and post-processing: constant/transpose folding into FP8 weights, two attention-aware Q/DQ graph rewrites, FP16/FP32 scale-cast folding utilities, simplified FP8 export helpers, LayerNorm/Softmax quant-module registrations, and HF nested-attention detection tweaks.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/torch_onnx/vit_mha_quantization.py`:
- Around line 267-298: The current loop counts any MatMul consuming a
DequantizeLinear output (matmul_with_qdq) which falsely includes projection/MLP
matmuls; replace this with an attention-specific check: implement a helper
(e.g., is_attention_matmul(node, output_to_node, graph)) and only increment
matmul_with_qdq when that returns true. Make is_attention_matmul examine the
MatMul node name and upstream pattern (check parent ops via output_to_node for
Transpose/Reshape, Softmax, or names containing tokens like "q", "k", "v",
"attn", "score", "softmax") or detect the Q@K^T pattern by verifying one input
path comes from a Transpose of a Q-like tensor and the other from K-like tensor;
for attn@V detect the MatMul consuming Softmax output and a V-like source.
Update the loop that currently inspects node.op_type == "MatMul" and uses
inputs_from_dq to call this helper and only count/print when both QDQ and
attention pattern match.
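The suggested helper can be sketched over plain dict stand-ins for ONNX graph nodes; `is_attention_matmul`, the token list, and the dict shape below are illustrative assumptions drawn from the review prompt, not the repository's actual API:

```python
# Hypothetical sketch of the reviewer-suggested helper. Nodes are plain dicts
# ({"op": ..., "name": ..., "inputs": [...]}) standing in for ONNX graph nodes;
# a real implementation would walk onnx.GraphProto or an onnx-graphsurgeon graph.

_ATTN_TOKENS = ("attn", "score", "softmax", "query", "key", "value", "/q", "/k", "/v")

def is_attention_matmul(node, output_to_node):
    """Heuristically decide whether a MatMul belongs to attention (Q@K^T or attn@V)."""
    if node["op"] != "MatMul":
        return False
    # Name-based signal: attention MatMuls are usually scoped under attention modules.
    if any(tok in node["name"].lower() for tok in _ATTN_TOKENS):
        return True
    # Structural signal: attn@V consumes a Softmax output; Q@K^T typically
    # consumes a Transpose of the key tensor.
    parent_ops = {
        output_to_node[t]["op"] for t in node["inputs"] if t in output_to_node
    }
    return "Softmax" in parent_ops or "Transpose" in parent_ops
```

A MatMul then only increments `matmul_with_qdq` when it both consumes a DequantizeLinear output and passes this check, so projection/MLP matmuls no longer inflate the count.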
- Around line 225-230: The export is mutating the live model because
model.float() is in-place, which alters base_model/quantized_model used later;
fix by exporting from a detached copy instead (e.g., create a deep copy of model
with copy.deepcopy(model) and call .float() or .to(torch.float16) on that copy)
so get_onnx_bytes_and_metadata receives a non-mutated model; ensure you import
copy and use the copied instance when calling get_onnx_bytes_and_metadata to
avoid changing base_model/quantized_model before accuracy evaluation.
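The deep-copy pattern the comment describes can be sketched as below; `ToyModel` is only a stand-in for the live model, since `nn.Module.float()` mutates the module in place:

```python
import copy

# ToyModel mimics the in-place dtype mutation the review flags: calling
# .float() changes the live object, so export must work on a deep copy.

class ToyModel:
    def __init__(self):
        self.dtype = "float16"

    def float(self):
        self.dtype = "float32"  # in-place, like nn.Module.float()
        return self

def export_safely(model):
    """Export from a detached copy so the live model keeps its dtype."""
    export_model = copy.deepcopy(model).float()
    # ...pass export_model (not model) to get_onnx_bytes_and_metadata(...)
    return export_model
```

With this pattern, `base_model`/`quantized_model` keep their original dtype for the later accuracy evaluation.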
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 100-108: The code currently uses any(c.op == "MatMul" for c in
candidate.outputs[0].outputs) and then rewires/clears all consumers which breaks
non-MatMul branches; change the logic to require all(c.op == "MatMul" for c in
candidate.outputs[0].outputs) before performing the global rewrite OR,
preferably, only rewrite the specific MatMul edges: iterate
candidate.outputs[0].outputs, for each consumer c with c.op == "MatMul" rewire
that consumer's input to use the transposed/scaled/quantized tensor and leave
other consumers untouched, and do not clear original outputs (update
transpose_to_remove only when all downstream edges have been safely redirected).
Apply the same fix pattern to the other rewrite sites that manipulate
torch_weights, perm, transpose_to_remove, and similar MatMul-aware transforms.
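The per-edge rewiring can be sketched as follows, again over dict-based stand-in nodes; `rewire_matmul_consumers` is a hypothetical name, not a function in the codebase:

```python
# Sketch of the per-edge fix: instead of clearing every consumer when *any*
# consumer is a MatMul, point only the MatMul edges at the transposed/quantized
# tensor and leave other branches untouched.

def rewire_matmul_consumers(old_tensor, new_tensor, consumers):
    """Point only MatMul consumers at new_tensor; return True iff all edges moved."""
    all_moved = True
    for c in consumers:
        if c["op"] == "MatMul":
            c["inputs"] = [new_tensor if t == old_tensor else t for t in c["inputs"]]
        else:
            all_moved = False  # a non-MatMul branch still reads old_tensor
    return all_moved  # caller may drop the Transpose only when this is True
```

The return value captures the `transpose_to_remove` condition: the original node is only safe to delete once every downstream edge has been redirected.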
In `@modelopt/onnx/utils.py`:
- Around line 1422-1505: The fold helpers unconditionally convert Q/DQ scale
initializers to FLOAT16 which is invalid for opsets < BASE_MIN_OPSET; update
_scale_fp32_to_fp16, fold_dq_fp32_to_fp16_casts and fold_q_fp16_to_fp32_casts to
guard the mutation by checking get_opset_version(onnx_model) (or the model
passed in) and only perform the FP32→FP16 rewrite when
get_opset_version(onnx_model) >= BASE_MIN_OPSET; if the check fails, skip
mutating initializers and skip folding the cast nodes (i.e., return the model
unchanged or continue without calling _scale_fp32_to_fp16/_bypass_cast_node),
using the existing function names (_scale_fp32_to_fp16,
fold_dq_fp32_to_fp16_casts, fold_q_fp16_to_fp32_casts) and constants
(BASE_MIN_OPSET) to locate where to add the guard.
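A minimal sketch of the guard, with a dict standing in for the ONNX model and `BASE_MIN_OPSET` assumed to be 19 per the FP16 Q/DQ scale requirement (the real helpers operate on `onnx.ModelProto` initializers):

```python
# FP16 scales in QuantizeLinear/DequantizeLinear are only valid from opset 19.
BASE_MIN_OPSET = 19

def fold_scale_casts(model, get_opset_version=lambda m: m["opset"]):
    """Only perform the FP32->FP16 scale rewrite when the opset allows it."""
    if get_opset_version(model) < BASE_MIN_OPSET:
        return model  # skip mutating initializers and leave the Cast nodes alone
    model = dict(model)
    model["scales_dtype"] = "FLOAT16"  # placeholder for the real initializer rewrite
    return model
```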
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 005378aa-8fac-4f2d-98a1-55297415cbe3
📒 Files selected for processing (8)
- examples/torch_onnx/vit_mha_quantization.py
- modelopt/onnx/export/fp8_exporter.py
- modelopt/onnx/utils.py
- modelopt/torch/_deploy/utils/torch_onnx.py
- modelopt/torch/quantization/export_onnx.py
- modelopt/torch/quantization/nn/__init__.py
- modelopt/torch/quantization/nn/modules/quant_layernorm.py
- modelopt/torch/quantization/plugins/huggingface.py
cjluo-nv left a comment
This PR adds FP8 MHA quantization support for HuggingFace ViT models with ONNX export optimizations. The implementation is well-structured and addresses a real gap (NVBug 6078291). However, there are several issues to address:
Critical issues:
1. No unit tests — This is ~933 lines of new/changed library code across core export paths (fp8_exporter.py, utils.py, export_onnx.py, huggingface.py) with zero unit tests. The only "test" is the example script, which requires GPU, ImageNet data, and TRT. The graph rewrite functions in fp8_exporter.py, the cast folding helpers in utils.py, the attention skipping logic in huggingface.py, and the LayerNorm quantization registration all need unit tests.
2. Bare `assert` for runtime validation in fp8_exporter.py — the existing assert on QDQ pair validation will be stripped with -O.
3. Silent `contextlib.suppress(Exception)` in the example — can mask real failures during benchmark parsing.
Minor issues:
4. The `_scale_fp32_to_fp16` helper doesn't handle the case where the scale value overflows or underflows to inf/0 in FP16 — this could silently produce bad quantization results for extreme scales.
5. The `_move_mul_before_qdq` rewrite assumes a single scalar const Mul for attention scaling; if the model architecture changes, these pattern-matching rewrites could silently become no-ops without any warning.
6. The `_insert_qdq_after_softmax` hardcodes scale=1/448, which is correct for E4M3 but should at minimum document why this specific value and that it's tied to the FP8 E4M3 max representable value.
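For item 4, a range check before the FP32→FP16 scale rewrite could look like this; the constants are the IEEE 754 half-precision limits, and the function name is illustrative rather than the helper's real signature:

```python
# Values outside the representable FP16 range silently become inf or 0 when
# the scale initializer is cast; the reviewer asks for a guard (or warning)
# before performing the rewrite.

FP16_MAX = 65504.0            # largest finite half-precision value
FP16_MIN_SUBNORMAL = 2.0 ** -24  # smallest positive half-precision value (~5.96e-8)

def scale_survives_fp16(scale: float) -> bool:
    """True iff casting this scale to FP16 neither overflows nor flushes to zero."""
    mag = abs(scale)
    return FP16_MIN_SUBNORMAL <= mag <= FP16_MAX
```

A caller could warn and skip the fold for any scale where this returns False.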
Positive aspects:
- Clean separation of graph rewrites as static methods
- Good docstrings on the new functions
- The parent_attention_types detection for avoiding double-patching is well done
- The LayerNorm registration follows existing patterns exactly
Codecov Report: ❌ Patch coverage is
Additional details and impacted files:
@@           Coverage Diff            @@
##             main    #1289    +/-  ##
=======================================
+ Coverage   74.60%   75.73%   +1.12%
=======================================
  Files         467      468       +1
  Lines       50176    50374     +198
=======================================
+ Hits        37435    38151     +716
+ Misses      12741    12223     -518
Force-pushed d6533ac to c436553
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/onnx/export/fp8_exporter.py (1)
79-81: Replace `assert` with an explicit exception for runtime validation. Per codebase conventions, `assert` statements are stripped with the `-O` flag; use `raise RuntimeError(...)` for runtime validation that must always execute.
Suggested fix:
-    assert dq_op.op == "TRT_FP8DequantizeLinear", (
-        f"QDQ does not occur in pairs. You reached {dq_op.op}"
-    )
+    if dq_op.op != "TRT_FP8DequantizeLinear":
+        raise RuntimeError(f"QDQ does not occur in pairs. You reached {dq_op.op}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/export/fp8_exporter.py` around lines 79 - 81, The assertion using assert dq_op.op == "TRT_FP8DequantizeLinear" in fp8_exporter.py must be replaced with an explicit runtime check that always runs: check the condition on dq_op.op and if it fails raise a RuntimeError with the same descriptive message (e.g., f"QDQ does not occur in pairs. You reached {dq_op.op}"); update the code around the existing dq_op.op check rather than using assert so the validation remains active under optimized runs.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/torch_onnx/vit_mha_quantization.py`:
- Around line 428-429: Add a boolean CLI flag (e.g., --trust_remote_code)
defaulting to False to the argument parser and expose it as
args.trust_remote_code, then pass that value into the model/component loading
calls (replace ViTImageProcessor.from_pretrained(args.model_name) and
ViTForImageClassification.from_pretrained(args.model_name) with calls that
include trust_remote_code=args.trust_remote_code) so callers can opt-in to
remote code execution while keeping the default safe.
---
Nitpick comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 79-81: The assertion using assert dq_op.op ==
"TRT_FP8DequantizeLinear" in fp8_exporter.py must be replaced with an explicit
runtime check that always runs: check the condition on dq_op.op and if it fails
raise a RuntimeError with the same descriptive message (e.g., f"QDQ does not
occur in pairs. You reached {dq_op.op}"); update the code around the existing
dq_op.op check rather than using assert so the validation remains active under
optimized runs.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 84a8283a-4d75-4771-9733-e78e49eaf910
📒 Files selected for processing (8)
- examples/torch_onnx/vit_mha_quantization.py
- modelopt/onnx/export/fp8_exporter.py
- modelopt/onnx/utils.py
- modelopt/torch/_deploy/utils/torch_onnx.py
- modelopt/torch/quantization/export_onnx.py
- modelopt/torch/quantization/nn/__init__.py
- modelopt/torch/quantization/nn/modules/quant_layernorm.py
- modelopt/torch/quantization/plugins/huggingface.py
✅ Files skipped from review due to trivial changes (2)
- modelopt/torch/quantization/nn/modules/quant_layernorm.py
- modelopt/torch/quantization/plugins/huggingface.py
🚧 Files skipped from review as they are similar to previous changes (4)
- modelopt/torch/quantization/nn/__init__.py
- modelopt/torch/quantization/export_onnx.py
- modelopt/torch/_deploy/utils/torch_onnx.py
- modelopt/onnx/utils.py
processor = ViTImageProcessor.from_pretrained(args.model_name)
base_model = ViTForImageClassification.from_pretrained(args.model_name).eval().to(device)
Expose trust_remote_code as a CLI parameter defaulting to False.
The --model_name argument allows users to specify any HuggingFace model. Some models require trust_remote_code=True, which enables execution of arbitrary Python shipped with the checkpoint. Per coding guidelines, this should be a caller-configurable parameter defaulting to False.
Suggested fix
parser.add_argument("--skip_onnx_ptq", action="store_true", help="Skip ONNX PTQ path")
+ parser.add_argument(
+ "--trust_remote_code",
+ action="store_true",
+ help="Trust remote code when loading HuggingFace models (security risk)",
+ )
args = parser.parse_args()Then update the loading calls:
- processor = ViTImageProcessor.from_pretrained(args.model_name)
- base_model = ViTForImageClassification.from_pretrained(args.model_name).eval().to(device)
+ processor = ViTImageProcessor.from_pretrained(
+ args.model_name, trust_remote_code=args.trust_remote_code
+ )
+ base_model = ViTForImageClassification.from_pretrained(
+ args.model_name, trust_remote_code=args.trust_remote_code
+ ).eval().to(device)As per coding guidelines: "Do not hardcode trust_remote_code=True when loading Hugging Face Transformers models. Let the caller decide via a parameter; default to False."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/torch_onnx/vit_mha_quantization.py` around lines 428 - 429, Add a
boolean CLI flag (e.g., --trust_remote_code) defaulting to False to the argument
parser and expose it as args.trust_remote_code, then pass that value into the
model/component loading calls (replace
ViTImageProcessor.from_pretrained(args.model_name) and
ViTForImageClassification.from_pretrained(args.model_name) with calls that
include trust_remote_code=args.trust_remote_code) so callers can opt-in to
remote code execution while keeping the default safe.
Force-pushed c436553 to 9bfcb72
🧹 Nitpick comments (2)
modelopt/onnx/export/fp8_exporter.py (2)
428-434: Unreachable defensive check. The check `if consumer is q_node: continue` at line 429 is unreachable because `q_node` is created at lines 413-419, after `consumers` was captured at line 388. The `q_node` cannot be in the `consumers` list. This is harmless but adds dead code that could confuse future readers.
♻️ Suggested removal:
     for consumer in consumers:
-        if consumer is q_node:
-            continue
         for i, inp in enumerate(consumer.inputs):
             if inp is softmax_output:
                 consumer.inputs[i] = dq_output
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/export/fp8_exporter.py` around lines 428 - 434, The loop contains an unreachable defensive check "if consumer is q_node: continue" because q_node is created after consumers was captured, so remove that check to avoid dead code; in the block that iterates over consumers (the for consumer in consumers: loop that replaces softmax_output with dq_output in consumer.inputs), delete the "if consumer is q_node: continue" line and leave the replacement logic intact (references: q_node, consumers, softmax_output, dq_output, count).
30-33: Consider deriving `_FP8_E4M3_MAX` from `torch.finfo` for consistency. The hardcoded value 448.0 is correct, but modelopt/torch/quantization/qtensor/mxfp8_tensor.py uses `torch.finfo(torch.float8_e4m3fn).max` to obtain this value programmatically. Using the same pattern here would be more robust and self-documenting.
♻️ Suggested change:
-# FP8 E4M3 max representable magnitude; softmax output in [0, 1] saturates exactly at 1.0
-# when using 1/448 as the Q scale.
-_FP8_E4M3_MAX = 448.0
+# FP8 E4M3 max representable magnitude; softmax output in [0, 1] saturates exactly at 1.0
+# when using 1/448 as the Q scale.
+_FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
 _FP8_E4M3_SOFTMAX_SCALE = 1.0 / _FP8_E4M3_MAX
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/export/fp8_exporter.py` around lines 30 - 33, The constant _FP8_E4M3_MAX is hardcoded but should be derived from torch.finfo for consistency with mxfp8_tensor; replace the literal 448.0 with torch.finfo(torch.float8_e4m3fn).max and recompute _FP8_E4M3_SOFTMAX_SCALE as 1.0 / _FP8_E4M3_MAX; ensure torch is imported in this module and preserve the existing constant names (_FP8_E4M3_MAX and _FP8_E4M3_SOFTMAX_SCALE) so other references remain valid.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 428-434: The loop contains an unreachable defensive check "if
consumer is q_node: continue" because q_node is created after consumers was
captured, so remove that check to avoid dead code; in the block that iterates
over consumers (the for consumer in consumers: loop that replaces softmax_output
with dq_output in consumer.inputs), delete the "if consumer is q_node: continue"
line and leave the replacement logic intact (references: q_node, consumers,
softmax_output, dq_output, count).
- Around line 30-33: The constant _FP8_E4M3_MAX is hardcoded but should be
derived from torch.finfo for consistency with mxfp8_tensor; replace the literal
448.0 with torch.finfo(torch.float8_e4m3fn).max and recompute
_FP8_E4M3_SOFTMAX_SCALE as 1.0 / _FP8_E4M3_MAX; ensure torch is imported in this
module and preserve the existing constant names (_FP8_E4M3_MAX and
_FP8_E4M3_SOFTMAX_SCALE) so other references remain valid.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: ec567acc-14ac-471b-b217-d45efe89af22
📒 Files selected for processing (9)
- CHANGELOG.rst
- examples/torch_onnx/vit_mha_quantization.py
- modelopt/onnx/export/fp8_exporter.py
- modelopt/onnx/utils.py
- modelopt/torch/_deploy/utils/torch_onnx.py
- modelopt/torch/quantization/export_onnx.py
- modelopt/torch/quantization/nn/__init__.py
- modelopt/torch/quantization/nn/modules/quant_layernorm.py
- modelopt/torch/quantization/plugins/huggingface.py
✅ Files skipped from review due to trivial changes (2)
- CHANGELOG.rst
- modelopt/torch/quantization/nn/modules/quant_layernorm.py
🚧 Files skipped from review as they are similar to previous changes (6)
- modelopt/torch/quantization/nn/__init__.py
- modelopt/torch/_deploy/utils/torch_onnx.py
- modelopt/torch/quantization/plugins/huggingface.py
- modelopt/torch/quantization/export_onnx.py
- examples/torch_onnx/vit_mha_quantization.py
- modelopt/onnx/utils.py
Force-pushed 9bfcb72 to ef8c769
Actionable comments posted: 2
🧹 Nitpick comments (1)
modelopt/onnx/export/fp8_exporter.py (1)
391-392: Remove the unreachable QuantizeLinear check in the softmax rewrite. After line 389 (`all(c.op == "MatMul" for c in consumers)`), line 391 cannot be true in this code path.
all(c.op == "MatMul" for c in consumers)), Line 391 cannot be true in this code path.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/export/fp8_exporter.py` around lines 391 - 392, In the softmax rewrite inside fp8_exporter.py remove the redundant unreachable check that tests for any(c.op == "QuantizeLinear" for c in consumers) after the preceding all(c.op == "MatMul" for c in consumers) guard; update the block around the softmax rewrite (look for the function/method handling the consumers list and the if statements using all(...) and any(...)) to delete the second check and its continue so the logic relies only on the MatMul consumer predicate.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 99-104: The code assumes a Cast node (identified as
cast_to_remove) can be fully cleared when a Transpose child is found, but that
disconnects other live consumers; change the removal logic so that after finding
candidate = Transpose you check cast_to_remove.outputs for other consumers
besides that Transpose and only remove the specific edge/consumer (or skip
removing the Cast entirely) instead of calling a blanket clear on
cast_to_remove.outputs; locate and update the removal at the code that clears
Cast outputs (the logic referencing cast_to_remove and candidate/Transpose later
around where outputs are cleared) to remove only the Transpose consumer (e.g.,
remove that output link or reroute it) while leaving other consumers intact.
- Around line 271-279: The rewrite moves Mul/Transpose across
Quantize/Dequantize but fails to guard upstream fanout; add single-consumer
checks before mutating upstream nodes: for the variables and nodes referenced
(dq_node, q_output, q_node, q_input, and the DequantizeLinear/QuantizeLinear
pairs) ensure the upstream variable that will be rewritten has outputs length ==
1 (i.e., only consumed by the DQ/transpose path) and that the DQ/transpose node
itself does not have other consumers, and only then perform the q_node.inputs[0]
mutation; apply the same single-consumer guard to the other similar blocks (the
code around the other uses of dq_node/q_node/q_output at the locations noted) so
unrelated branches aren’t modified.
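The single-consumer guard can be sketched with a consumer map over tensor names; the names, dict shapes, and `move_node_before_q` helper are stand-ins for illustration, not the exporter's real data model:

```python
# graph_consumers maps tensor name -> list of consumer node names.
# "movable" is the Mul/Transpose being hoisted across the Q/DQ pair.

def move_node_before_q(q_node, movable, graph_consumers):
    """Rewire q_node to read movable's input, only when fanout is safe."""
    # Guard 1: the movable node's output must feed only this Q node.
    if graph_consumers[movable["output"]] != [q_node["name"]]:
        return False
    # Guard 2: the movable node's input must have no other consumers,
    # otherwise moving it would change an unrelated branch.
    if graph_consumers[movable["input"]] != [movable["name"]]:
        return False
    q_node["inputs"][0] = movable["input"]
    return True
```

Applying the same two guards at each similar rewrite site keeps unrelated branches untouched when either tensor has fanout.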
---
Nitpick comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 391-392: In the softmax rewrite inside fp8_exporter.py remove the
redundant unreachable check that tests for any(c.op == "QuantizeLinear" for c in
consumers) after the preceding all(c.op == "MatMul" for c in consumers) guard;
update the block around the softmax rewrite (look for the function/method
handling the consumers list and the if statements using all(...) and any(...))
to delete the second check and its continue so the logic relies only on the
MatMul consumer predicate.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 8107c4b4-ca4e-4406-b813-c8988638709b
📒 Files selected for processing (8)
- CHANGELOG.rst
- modelopt/onnx/export/fp8_exporter.py
- modelopt/onnx/utils.py
- modelopt/torch/_deploy/utils/torch_onnx.py
- modelopt/torch/quantization/export_onnx.py
- modelopt/torch/quantization/nn/__init__.py
- modelopt/torch/quantization/nn/modules/quant_layernorm.py
- modelopt/torch/quantization/plugins/huggingface.py
✅ Files skipped from review due to trivial changes (1)
- modelopt/torch/quantization/nn/modules/quant_layernorm.py
🚧 Files skipped from review as they are similar to previous changes (6)
- modelopt/torch/quantization/nn/__init__.py
- modelopt/torch/_deploy/utils/torch_onnx.py
- modelopt/torch/quantization/plugins/huggingface.py
- modelopt/onnx/utils.py
- modelopt/torch/quantization/export_onnx.py
- CHANGELOG.rst
Force-pushed ef8c769 to 928d417
🧹 Nitpick comments (1)
examples/torch_onnx/torch_quant_to_onnx.py (1)
186-215: Consider adding a fallback for `self.attn_dim` to support older timm versions. Line 210 uses `self.attn_dim`, which is a relatively recent addition to timm's Attention class (circa late 2025). Since pyproject.toml does not pin a specific timm version, older releases may lack this attribute. If broad timm compatibility is intended, use `getattr(self, 'attn_dim', self.num_heads * self.head_dim)` to gracefully handle versions that compute this value dynamically instead of exposing it as an attribute.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 186 - 215, In _vit_attention_forward replace direct use of self.attn_dim with a safe fallback so older timm versions that lack the attribute don't break: compute attn_dim = getattr(self, "attn_dim", self.num_heads * self.head_dim) and use that variable when reshaping (and anywhere else self.attn_dim is referenced) so the method works whether the attribute exists or must be derived from num_heads and head_dim.
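The fallback amounts to a single `getattr` call; `FakeAttention` below is only a stand-in for timm's Attention module, used to show the behavior with and without the attribute:

```python
# Stand-in for a timm Attention module that predates the attn_dim attribute.
class FakeAttention:
    num_heads = 12
    head_dim = 64

def resolve_attn_dim(module):
    """Use attn_dim when present, else derive it from num_heads * head_dim."""
    return getattr(module, "attn_dim", module.num_heads * module.head_dim)
```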
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@examples/torch_onnx/torch_quant_to_onnx.py`:
- Around line 186-215: In _vit_attention_forward replace direct use of
self.attn_dim with a safe fallback so older timm versions that lack the
attribute don't break: compute attn_dim = getattr(self, "attn_dim",
self.num_heads * self.head_dim) and use that variable when reshaping (and
anywhere else self.attn_dim is referenced) so the method works whether the
attribute exists or must be derived from num_heads and head_dim.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 08cceb9c-1496-4b86-8b2f-60f5dbdbd74c
📒 Files selected for processing (10)
- CHANGELOG.rst
- examples/torch_onnx/torch_quant_to_onnx.py
- modelopt/onnx/export/fp8_exporter.py
- modelopt/onnx/utils.py
- modelopt/torch/_deploy/utils/torch_onnx.py
- modelopt/torch/quantization/export_onnx.py
- modelopt/torch/quantization/nn/__init__.py
- modelopt/torch/quantization/nn/modules/quant_layernorm.py
- modelopt/torch/quantization/nn/modules/quant_softmax.py
- modelopt/torch/quantization/plugins/huggingface.py
✅ Files skipped from review due to trivial changes (2)
- CHANGELOG.rst
- modelopt/torch/quantization/export_onnx.py
🚧 Files skipped from review as they are similar to previous changes (4)
- modelopt/torch/_deploy/utils/torch_onnx.py
- modelopt/torch/quantization/nn/modules/quant_layernorm.py
- modelopt/torch/quantization/plugins/huggingface.py
- modelopt/torch/quantization/nn/__init__.py
Force-pushed 48d8486 to ce9165e
Force-pushed f6f62a3 to 9af7e18
cjluo-nv left a comment
All critical issues from previous reviews have been addressed: unit tests added (4 test files), bare assert replaced with RuntimeError, all graph rewrites guarded with all(MatMul) fanout checks, opset guards added to FP16 scale folding, FP16 overflow warning added, per-instance nested attention detection implemented, and unused parameter documented. The remaining minor items (unreachable dead code in _insert_qdq_after_softmax, copyright year 2024 on new files) are non-blocking.
Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar transformer vision models) when exported to ONNX with FP8 Q/DQ.
- fp8_exporter: rewrite attention-scaling Mul and K Transpose to the Q-side so DQ feeds MatMul directly, pre-transpose weight constants, insert FP8 Q/DQ on Softmax outputs for MHA-v2 fusion. Scale dtype now matches the graph's float dtype to keep strongly-typed builds consistent.
- onnx/utils: fold Cast(FP16<->FP32) nodes that convert_float_to_float16 inserts around Q/DQ by rewriting scale initializers to FP16, so TRT fuses DQ into the downstream GEMM/MatMul kernel.
- torch/quantization/export_onnx: keep FP8 Q/DQ scale in the native input dtype so no Cast is injected between graph and Q/DQ.
- torch/quantization/nn: register nn.LayerNorm in QuantModuleRegistry so LayerNorm output quantizers are honored.
- torch/quantization/plugins/huggingface: skip attention wrappers whose children are also "*Attention" to avoid double-patching eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention).
Example: examples/torch_onnx/vit_mha_quantization.py shows a ViT-FP8 config (extends FP8_DEFAULT_CFG with LayerNorm output quantizer, disabled input quantizers on LayerNorm-followed layers, and *_bmm_quantizer entries) plus accuracy + TRT-latency comparison against an FP16 baseline.
Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1):
- Top-1 / top-5 on 5k ImageNet-val: 81.16% / 95.50% (FP16) vs 80.96% / 95.44% (torch FP8) — -0.20% / -0.06%
- TRT latency: 0.721 ms (FP16) vs 0.646 ms (torch FP8) — 1.12x speedup
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Two independent bugs surfaced by the parametrized matrix in tests/examples/torch_onnx/test_torch_quant_to_onnx.py:
- MXFP8/NVFP4 lower input quantizers to TRT DynamicQuantize, which only supports 2D/3D input. Swin/SwinV2 keep the 4D (B, H, W, C) layout on per-block norm1, downsample.norm, and the top-level norm, causing trtexec (MXFP8) and the NVFP4 autocast TRT-shape-inference pre-pass to reject the graph. Added _disable_high_rank_input_quantizers, which runs a forward-pass rank probe and disables quantizers on 4D+ inputs; gated on mxfp8 / nvfp4 / auto so FP8 and INT8 still quantize those layers (their Q/DQ has no rank constraint). Name-based alternatives would false-positive on ViT, whose same-named top-level norm is 3D.
- swinv2_tiny-fp8 hit ZeroDivisionError in export_fp8 (448 / amax): timm's res-post-norm scheme zero-inits each SwinV2 block's norm1 / norm2 weight and bias, so under --no_pretrained those LayerNorm outputs are exactly zero, and the FP8 MHA override's output_quantizer calibrates to amax == 0. Added _disable_dead_quantizers to drop any quantizer whose calibrated amax is NaN or <= 0 before export.
Full matrix (4 models x 5 modes) now passes: 20/20 in ~33 min.
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
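The dead-quantizer filter described in the commit message can be sketched over dict stand-ins; the real `_disable_dead_quantizers` presumably walks TensorQuantizer modules rather than dicts:

```python
import math

def disable_dead_quantizers(quantizers):
    """Disable any quantizer whose calibrated amax is missing, NaN, or <= 0."""
    for q in quantizers.values():
        amax = q.get("amax")
        if amax is None or math.isnan(amax) or amax <= 0:
            q["enabled"] = False  # avoids e.g. ZeroDivisionError in 448 / amax
    return quantizers
```

Running this before export drops quantizers calibrated on all-zero activations (such as SwinV2's zero-initialized res-post-norm LayerNorms under --no_pretrained).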
The test module imports from modelopt.torch.quantization.plugins.huggingface,
which imports transformers at module scope. Under the partial-install (torch)
CI job — which installs only torch, without transformers/onnx/diffusers —
collection failed with ModuleNotFoundError, taking the whole unit-torch
partial-install step down.
Add pytest.importorskip("transformers") before the plugin import, matching
the pattern used by the sibling test_fused_experts.py.
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
CI (Blackwell, compute capability 12.0) fails TRT engine build for resnet50 under fp8 / mxfp8 / nvfp4 / auto:
Error Code 10: Could not find any implementation for node /conv1/input_quantizer/TRT_FP8QuantizeLinear ... [ElementWise]
The node is ResNet50's top-level conv1 (7x7 stride-2, in_channels=3). TRT's Blackwell tactics for FP8 Q -> Conv fusion don't cover the raw-RGB (3-channel) first-layer pattern. Ada (compute capability 8.9, the local dev GPU) happens to have a tactic, which is why the matrix passed locally. Swin/ViT avoid this because their first conv (patch_embed.proj, also 3-channel) is already excluded by filter_func's patch_embed pattern. ResNet50's conv1 wasn't on any list.
Add _disable_low_channel_conv_input_quantizers to disable the input_quantizer on any Conv2d with in_channels <= 3 for FP8-family modes. Weight quantization is preserved. This also aligns with standard quantization practice (leave first/last layers in higher precision). INT8 is unchanged - INT8 Q/DQ has broader TRT kernel coverage on Blackwell and built successfully in CI.
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
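The described filter, sketched over dict stand-ins for modules (the mode names come from the commit message; the real helper presumably inspects nn.Conv2d instances):

```python
FP8_FAMILY = {"fp8", "mxfp8", "nvfp4", "auto"}

def disable_low_channel_conv_inputs(modules, mode, max_in_channels=3):
    """Disable input quantizers on low-channel convs for FP8-family modes only."""
    if mode not in FP8_FAMILY:
        return modules  # INT8 Q/DQ has broader TRT kernel coverage; leave it alone
    for m in modules.values():
        if m["type"] == "Conv2d" and m["in_channels"] <= max_in_channels:
            m["input_quantizer_enabled"] = False  # weight quantizer stays enabled
    return modules
```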
Force-pushed 6537f37 to a9f87bf
Summary
Enables TensorRT attention-v2 fusion for vision transformers when exported to ONNX with FP8 Q/DQ. The core library changes are architecture-agnostic (drop-in for any FP8 ONNX export); coverage is exercised by the existing examples/torch_onnx/torch_quant_to_onnx.py pipeline.
- modelopt/onnx/export/fp8_exporter.py — new post-processing passes: move the attention-scaling Mul and K Transpose to the Q-side so DQ feeds MatMul directly, pre-transpose constant weights, and insert FP8 Q/DQ on Softmax outputs (fixed 1/448 scale, data-independent) for MHA-v2 fusion. Rewrites only fire when every downstream consumer is a MatMul, so non-attention branches are never perturbed.
- modelopt/onnx/utils.py — fold_dq_fp32_to_fp16_casts / fold_q_fp16_to_fp32_casts remove the Cast nodes convert_float_to_float16 inserts around Q/DQ and rewrite scale initializers to FP16 so TRT fuses DQ into the downstream GEMM. Guarded behind opset >= 19 (FP16 Q/DQ scale requirement). Warns on FP16 overflow/underflow.
- modelopt/torch/_deploy/utils/torch_onnx.py — calls the fold helpers for FP8-quantized models after convert_float_to_float16.
- modelopt/torch/quantization/export_onnx.py — keeps the FP8 Q/DQ scale in the native input dtype so no Cast is emitted between graph and Q/DQ. Removes the now-unused trt_high_precision_dtype parameter from _fp8_quantize / _fp8_dequantize.
- modelopt/torch/quantization/nn/modules/quant_layernorm.py (new) — registers nn.LayerNorm in QuantModuleRegistry so LayerNorm output quantizers are honored.
- modelopt/torch/quantization/plugins/huggingface.py — skips *Attention wrappers whose children are also *Attention per-instance (not per-class) to avoid double-patching eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention).
- examples/torch_onnx/torch_quant_to_onnx.py — adds a _FP8_MHA_OVERRIDE config block to FP8 mode that enables the LayerNorm output quantizer and disables its input quantizer for TRT attention fusion.
Benchmarks
ViT-base-patch16-224, RTX 6000 Ada, strongly-typed FP8 via trtexec. Accuracy on 2,000 ImageNet-1k validation samples (streaming).
Batch = 1 (latency-bound)
Batch = 64 (throughput-bound, realistic inference)
Top-1 accuracy stays within 0.30 pp of FP16; at batch=64 the Torch FP8 MHA path matches ONNX PTQ wall-time — attention is the bottleneck there, and both paths achieve full FP8 attention fusion (36/36 attention MatMuls with QDQ in ViT-base).
Test plan