Skip to content

[NVBug 6108145] Fix PTQ calibration and export for fused-experts MoE (Qwen3.5-MoE VLM)#1340

Open
meenchen wants to merge 2 commits intomainfrom
weimingc/fix_qwen36_moe_ptq
Open

[NVBug 6108145] Fix PTQ calibration and export for fused-experts MoE (Qwen3.5-MoE VLM)#1340
meenchen wants to merge 2 commits intomainfrom
weimingc/fix_qwen36_moe_ptq

Conversation

@meenchen
Copy link
Copy Markdown
Contributor

@meenchen meenchen commented Apr 24, 2026

What does this PR do?

Type of change: Bug fix

Fixes a 4-bug cascade that caused silent PTQ failure on Qwen3.5-MoE VLMs (Qwen3.6-35B-A3B): calibration
appeared to succeed but produced token-salad at inference. Root cause: HF's @use_experts_implementation
dispatches expert forward to torch._grouped_mm / torch.bmm, bypassing the F.linear hook that captures
activations — so gate_up_proj_input_quantizer / down_proj_input_quantizer never calibrated and no input_scale
tensors were emitted.

Changes:

  • examples/llm_ptq/hf_ptq.py — force config._experts_implementation = "eager" (recursing into text_config /
    vision_config / …) so per-expert F.linear calls are visible to the calibration hook.
  • modelopt/torch/quantization/conversion.py — normalize plural ModuleList quantizer names (weight_quantizers.N
    → weight_quantizer) before fnmatch, so wildcards like mlp.expertsweight_quantizer match fused-expert
    quantizers.
  • modelopt/torch/export/unified_export_hf.py — hoist the _QuantFusedExperts export branch above the
    get_quantization_format() gate so _export_fused_experts() runs even when the top-level format query returns
    QUANTIZATION_NONE (happens for experts-only recipes).
  • modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml — layerwise: false (VLM nested layer structure
    breaks the layerwise walker).

Usage

  python examples/llm_ptq/hf_ptq.py \
      --pyt_ckpt_path Qwen/Qwen3.6-35B-A3B \
      --qformat nvfp4 \
      --kv_cache_qformat fp8 \
      --calib_size 512 \
      --export_path Qwen3.6-35B-A3B-NVFP4

Testing

Testing

End-to-end PTQ → vLLM deploy → NEL eval on Qwen3.6-35B-A3B (256 experts × 40 layers, 35B params):

Hook-call diagnostic: 0 → 6720 per-expert F.linear calls during calibration after the fix; 0 → 30720
input_scale tensors emitted in the exported checkpoint.

FP8 fused-MoE path still produces gibberish — separate follow-up (vLLM per-expert weight_scale handling).

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features

    • Improved support for fused-expert modules in quantization pipelines with dedicated export paths.
    • Added plugin to optimize expert execution paths during calibration.
  • Bug Fixes

    • Fixed quantizer matching for expert-indexed modules with wildcard configurations.
    • Resolved calibration discovery for VLM models with specific layer structures.
  • Documentation

    • Updated quantization configuration with notes on layer structure handling.
  • Tests

    • Added comprehensive test coverage for fused-experts quantization workflow.

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
@meenchen meenchen requested review from a team as code owners April 24, 2026 05:17
@meenchen meenchen requested review from cjluo-nv and realAsma April 24, 2026 05:17
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Apr 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@meenchen meenchen self-assigned this Apr 24, 2026
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 24, 2026

📝 Walkthrough

Walkthrough

This pull request adds comprehensive support for HuggingFace fused-expert modules in PTQ quantization by introducing dedicated export paths, normalizing quantizer names for wildcard matching, forcing eager expert implementation, updating configuration, and extending test coverage.

Changes

Cohort / File(s) Summary
Fused-Experts Export Path
modelopt/torch/export/unified_export_hf.py
Added dedicated export logic that detects gate_up_proj_weight_quantizers on _QuantFusedExperts modules and invokes _export_fused_experts under fsdp2_aware_weight_update context, bypassing standard quantization-format detection.
Quantizer Name Normalization
modelopt/torch/quantization/conversion.py
Added normalization logic for quantizer names from _QuantFusedExperts ModuleList children to enable wildcard-based matching against both single quantizers and per-expert indexed quantizers (weight_quantizers.N, input_quantizers.N).
Eager Implementation Plugin
modelopt/torch/quantization/plugins/huggingface.py
New plugin force_eager_experts_impl_on_the_fly that mutates model config to force eager expert implementation, ensuring quantizer hooks are exercised during calibration and export instead of using grouped_mm/bmm backends.
Configuration Update
modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml
Disabled layerwise calibration (quantize.algorithm.layerwise from true to false) with documentation noting requirement for VLM models where decoder layers live under model.language_model.layers.
Test Coverage
tests/unit/torch/quantization/plugins/test_fused_experts.py
Extended test suite covering eager implementation config mutation (including nested configs) and end-to-end PTQ calibration verification for fused-experts modules with assertions on calibrated amax values.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 56.52% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately and specifically summarizes the main change: fixing PTQ calibration and export for fused-experts MoE models, with the concrete example of Qwen3.5-MoE VLM.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed No security anti-patterns detected in modified files; no unsafe deserialization, eval/exec, hardcoded trust_remote_code, or # nosec comments found.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch weimingc/fix_qwen36_moe_ptq

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Apr 24, 2026

PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1340/

Built to branch gh-pages at 2026-04-24 16:44 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@meenchen meenchen requested a review from sychen52 April 24, 2026 05:21
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 649-657: The elif branch that checks for the same attribute
gate_up_proj_weight_quantizers is dead code because the preceding if block
handles that case and continues; remove the unreachable elif block (the second
check for gate_up_proj_weight_quantizers and its body) so only the initial
handling using fsdp2_aware_weight_update and _export_fused_experts(sub_module,
dtype) remains, leaving no duplicate checks for gate_up_proj_weight_quantizers
in the loop.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 6c028304-1560-45eb-bcba-de04e7c03a20

📥 Commits

Reviewing files that changed from the base of the PR and between 5887410 and 7c6c132.

📒 Files selected for processing (4)
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/quantization/conversion.py
  • modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml

Comment on lines +649 to +657
# _QuantFusedExperts uses plural `gate_up_proj_weight_quantizers` (ModuleList),
# which get_quantization_format's singular-weight_quantizer check misses. Handle
# it explicitly before the format gate so fused-experts get split + quantized.
if hasattr(sub_module, "gate_up_proj_weight_quantizers"):
from modelopt.torch.export.moe_utils import _export_fused_experts

with fsdp2_aware_weight_update(model, sub_module, reshard=False):
_export_fused_experts(sub_module, dtype)
continue
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify that lines 689-695 are now dead code by checking if they can ever be reached
# for modules with gate_up_proj_weight_quantizers

# Show both code blocks for comparison
echo "=== Early check (lines 649-657) ==="
sed -n '649,657p' modelopt/torch/export/unified_export_hf.py

echo ""
echo "=== Later duplicate check (lines 689-695) ==="
sed -n '689,695p' modelopt/torch/export/unified_export_hf.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1176


🏁 Script executed:

#!/bin/bash
# Show the full if-elif-else structure around lines 649-695 to understand control flow
sed -n '645,700p' modelopt/torch/export/unified_export_hf.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 3631


Remove unreachable dead code at lines 689-695.

The early check at lines 649–657 catches all modules with gate_up_proj_weight_quantizers and executes continue, which skips to the next loop iteration. The elif at line 689 checking the same condition will never be reached because:

  1. Any module matching the condition is handled at line 652 and exits via continue
  2. The elif can only be checked if the first if was false, meaning the module lacks gate_up_proj_weight_quantizers

Lines 689–695 should be removed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_hf.py` around lines 649 - 657, The elif
branch that checks for the same attribute gate_up_proj_weight_quantizers is dead
code because the preceding if block handles that case and continues; remove the
unreachable elif block (the second check for gate_up_proj_weight_quantizers and
its body) so only the initial handling using fsdp2_aware_weight_update and
_export_fused_experts(sub_module, dtype) remains, leaving no duplicate checks
for gate_up_proj_weight_quantizers in the loop.

@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 24, 2026

Codecov Report

❌ Patch coverage is 54.54545% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.61%. Comparing base (0678136) to head (7c6c132).
⚠️ Report is 10 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/export/unified_export_hf.py 0.00% 5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1340      +/-   ##
==========================================
- Coverage   74.46%   72.61%   -1.85%     
==========================================
  Files         464      481      +17     
  Lines       50089    52610    +2521     
==========================================
+ Hits        37300    38205     +905     
- Misses      12789    14405    +1616     
Flag Coverage Δ
unit 52.71% <54.54%> (+0.26%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
@meenchen meenchen added bug Something isn't working cherry-pick-0.44.0 After code freeze, cherry-pick to release branch for next rc (bulk update). Only for bug fixes / doc labels Apr 24, 2026
Copy link
Copy Markdown
Collaborator

@cjluo-nv cjluo-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a well-structured bug fix with good test coverage. The 4-part fix (eager impl forcing, quantizer name normalization, export path hoisting, YAML layerwise change) is logically coherent and well-tested. However, there's a duplicate code path in unified_export_hf.py that should be cleaned up.

Design: Despite the complexity gate firing, this is a bug fix within existing systems, not an architectural change. No new abstractions are introduced.

Tests: Comprehensive — covers force_eager_experts_impl_on_the_fly edge cases, and an end-to-end calibration test that guards the full pipeline (name normalization → wildcard matching → amax collection). Good.

Issue: The new early-exit block in _process_quantized_modules makes the existing elif hasattr(sub_module, "gate_up_proj_weight_quantizers") block (deeper in the same function, inside the get_quantization_format() != QUANTIZATION_NONE guard) dead code. One of these should be removed.

from modelopt.torch.export.moe_utils import _export_fused_experts

with fsdp2_aware_weight_update(model, sub_module, reshard=False):
_export_fused_experts(sub_module, dtype)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new early-exit block makes the existing elif hasattr(sub_module, "gate_up_proj_weight_quantizers") block at the end of this function (inside the get_quantization_format(sub_module) != QUANTIZATION_NONE branch, around line ~700 in the full file) dead code — the continue here means we never reach that path.

Please remove the old dead-code block to avoid confusion for future maintainers who might not realize both paths exist.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (2)
modelopt/torch/quantization/plugins/huggingface.py (1)

1441-1449: Harden recursive config traversal with a cycle guard.

The recursive _force() walk can loop forever on cyclic config graphs. Add a visited-set guard to make this robust.

Proposed patch
     nested_cfg_attrs = ("text_config", "vision_config", "audio_config", "speech_config")
+    visited_cfg_ids = set()
 
     def _force(cfg):
         if cfg is None:
             return
+        cfg_id = id(cfg)
+        if cfg_id in visited_cfg_ids:
+            return
+        visited_cfg_ids.add(cfg_id)
         if hasattr(cfg, "_experts_implementation"):
             cfg._experts_implementation = "eager"
         for sub in nested_cfg_attrs:
             if hasattr(cfg, sub):
                 _force(getattr(cfg, sub))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 1441 - 1449,
The recursive helper _force can loop on cyclic config graphs; modify _force to
accept and maintain a visited set (e.g., of object ids) and skip recursing into
cfg instances already seen, so each cfg is processed only once. Specifically,
update the _force signature to take an optional visited set, add the current cfg
to visited (use id(cfg) or cfg itself), return early if already visited, and
keep the existing behavior of setting cfg._experts_implementation = "eager" and
iterating nested_cfg_attrs; ensure recursive calls pass the same visited set.
This change hardens _force and prevents infinite recursion on cycles while still
touching the same symbols (_force, nested_cfg_attrs,
cfg._experts_implementation).
tests/unit/torch/quantization/plugins/test_fused_experts.py (1)

385-441: Make registry cleanup exception-safe in the calibration test.

If an assertion fails before the last line, the temporary registry entry may leak into subsequent tests.

Proposed patch
     def test_calibration_populates_all_expert_quantizers(self):
         """After PTQ, every input/weight quantizer on the fused-experts module has amax set."""
         import modelopt.torch.quantization as mtq
 
         model = _TinyMoEModel()
         expert_type = type(model.moe.experts)
         self._cleanup_registry(expert_type)
-
-        quant_cfg = {
+        try:
+            quant_cfg = {
             "quant_cfg": [
                 {"quantizer_name": "*", "enable": False},
                 {
                     "quantizer_name": "*gate_up_proj_input_quantizer",
                     "cfg": {"num_bits": 8, "axis": None},
@@
         for idx in range(NUM_EXPERTS):
             assert experts.gate_up_proj_weight_quantizers[idx].amax is not None, (
                 f"gate_up_proj_weight_quantizers[{idx}].amax is None — "
                 "plural ModuleList name normalization in _match_quantizer likely broken."
             )
             assert experts.down_proj_weight_quantizers[idx].amax is not None, (
                 f"down_proj_weight_quantizers[{idx}].amax is None."
             )
-
-        self._cleanup_registry(expert_type)
+        finally:
+            self._cleanup_registry(expert_type)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/quantization/plugins/test_fused_experts.py` around lines 385
- 441, The test creates a temporary registry entry via expert_type and calls
self._cleanup_registry(expert_type) at the end but can leak if an assertion
fails; wrap the main test actions (quant_cfg setup, forward_loop, mtq.quantize,
and all asserts) in a try/finally and move the final
self._cleanup_registry(expert_type) into the finally block so cleanup always
runs; keep expert_type assigned before the try and leave the initial cleanup
call (before quantization) as-is.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@modelopt/torch/quantization/plugins/huggingface.py`:
- Around line 1441-1449: The recursive helper _force can loop on cyclic config
graphs; modify _force to accept and maintain a visited set (e.g., of object ids)
and skip recursing into cfg instances already seen, so each cfg is processed
only once. Specifically, update the _force signature to take an optional visited
set, add the current cfg to visited (use id(cfg) or cfg itself), return early if
already visited, and keep the existing behavior of setting
cfg._experts_implementation = "eager" and iterating nested_cfg_attrs; ensure
recursive calls pass the same visited set. This change hardens _force and
prevents infinite recursion on cycles while still touching the same symbols
(_force, nested_cfg_attrs, cfg._experts_implementation).

In `@tests/unit/torch/quantization/plugins/test_fused_experts.py`:
- Around line 385-441: The test creates a temporary registry entry via
expert_type and calls self._cleanup_registry(expert_type) at the end but can
leak if an assertion fails; wrap the main test actions (quant_cfg setup,
forward_loop, mtq.quantize, and all asserts) in a try/finally and move the final
self._cleanup_registry(expert_type) into the finally block so cleanup always
runs; keep expert_type assigned before the try and leave the initial cleanup
call (before quantization) as-is.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 310a1d2b-745c-4df5-8317-038bf82c199f

📥 Commits

Reviewing files that changed from the base of the PR and between 7c6c132 and 9414089.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/plugins/huggingface.py
  • tests/unit/torch/quantization/plugins/test_fused_experts.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working cherry-pick-0.44.0 After code freeze, cherry-pick to release branch for next rc (bulk update). Only for bug fixes / doc

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants