
[2/3][Feat]: Offline DFlash training #1295

Merged
h-guo18 merged 7 commits into haoguo/spec-file-reorg from haoguo/dflash-offline
Apr 23, 2026

Conversation

@h-guo18
Contributor

@h-guo18 h-guo18 commented Apr 19, 2026

What does this PR do?

Type of change: new feature

Part 2 of a 3-PR series splitting #1271:

Changes:

  • Add dflash_offline flag to DFlashConfig for training from pre-computed hidden states; deletes base model layers to save memory.
  • Add Pydantic validators on DFlashConfig:
    • _derive_dflash_offline — auto-derive dflash_offline from data_args.offline_data_path in validation context. Not user-configurable: any user-supplied value is overridden by the derived value.
    • _resolve_mask_token_id — auto-detect dflash_mask_token_id from tokenizer.mask_token_id.
    • _check_mask_token_id — fail fast if unset after resolution.
  • HFDFlashModel.modify(): select num_orig_hidden_layers when offline; pick _base_model_lm_head device when no base layers present; drop base-model layers module.
  • HFDFlashModel.forward(): add offline branch — consumes precomputed base_model_outputs via DFlashBaseModelOutput.from_offline_dict, and when dflash_self_logit_distillation is enabled with base_model_logits absent, recomputes logits from base_model_hidden_states via _base_model_lm_head. Raises a clear error from the non-training / pseudo_speculative_generate paths when dflash_offline=True, since base-model layers have been deleted.
  • DFlashBaseModelOutput dataclass in modeling_dflash.py (with from_offline_dict classmethod) to unify online/offline output shapes. aux_hidden_states is required in from_offline_dict so missing keys fail fast at the entry point rather than deeper in the forward.
  • examples/speculative_decoding/main.py: replace inline mask_token_id auto-detect with DFlashConfig.model_validate(dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}).

Silent bug fix: add_generation_template → add_generation_prompt

The pre-refactor compute_hidden_states_hf.py passed add_generation_template=False to tokenizer.apply_chat_template. This kwarg does not exist on HF apply_chat_template and was being silently ignored, so the intended "don't append a generation prompt" behavior was never actually applied. The new tokenize_with_loss_mask helper in examples/speculative_decoding/collect_hidden_states/common.py uses the correct add_generation_prompt=False. This is a real behavior change for anyone re-dumping hidden states: trailing generation prompts that were previously appended to the tokenized sequences will no longer be included.
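The failure mode is easy to reproduce in miniature: HF's apply_chat_template forwards unrecognized keyword arguments to the chat template as extra variables, so a misspelled flag is swallowed rather than rejected. The toy function below is not the HF implementation (the template markup is invented); it only shows why the typo went unnoticed.

```python
# Toy stand-in for apply_chat_template: unknown kwargs are collected into
# **template_kwargs (as template variables) instead of raising TypeError.
def apply_chat_template(messages, add_generation_prompt=True, **template_kwargs):
    text = "".join(f"<|{m['role']}|>{m['content']}" for m in messages)
    if add_generation_prompt:
        text += "<|assistant|>"
    return text


msgs = [{"role": "user", "content": "hi"}]

# Misspelled kwarg: silently absorbed, generation prompt still appended.
buggy = apply_chat_template(msgs, add_generation_template=False)

# Correct kwarg: generation prompt suppressed as intended.
fixed = apply_chat_template(msgs, add_generation_prompt=False)
```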

Testing

  • New tests:

    • tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py — CPU unit tests for convert path (online keeps base layers, offline deletes them; num_orig_hidden_layers drives target_layer_ids in offline mode) and DFlashConfig._derive_dflash_offline validator.
    • TestDFlashOfflineForwardGPU in tests/gpu/torch/speculative/plugins/test_hf_dflash.py — GPU forward smoke with precomputed base_model_outputs, plus the dflash_self_logit_distillation logit-recompute path.
  • Training test: verified manually; two run screenshots are attached on the GitHub PR page (not reproduced in this text export).

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ — additive dflash_offline flag defaulting to False; validators fall through when context not provided.
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅ — see Testing section above.
  • Did you update Changelog?: ✅

TODO (follow-up)

  • Update examples/speculative_decoding/collect_hidden_states/compute_hidden_states_*.py to support DFlash offline data. Current scripts are Eagle-specific — they hardcode the [2, N/2, N-3] aux-layer selection and emit {input_ids, hidden_states, aux_hidden_states}. DFlash offline needs:
    • Aux layer indices driven by build_target_layer_ids(num_orig_hidden_layers, num_draft_layers) (or a configurable list), not the Eagle triplet.
    • base_model_hidden_states key (last-layer hidden) so DFlashBaseModelOutput.from_offline_dict + the dflash_self_logit_distillation recompute path can consume it.
    • Optional base_model_logits dump so offline training can skip the self-distillation logit recomputation when logits are available.
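For reference, the fail-fast consumption pattern the dump scripts will need to satisfy looks roughly like this. It is a sketch reconstructed from the PR description, not the modeling_dflash.py source; tensor fields are typed Any for brevity.

```python
# Sketch of the offline-output contract described above (not the real source).
from dataclasses import dataclass
from typing import Any


@dataclass
class DFlashBaseModelOutput:
    """Unified shape for online and offline base-model outputs (sketch)."""

    base_model_hidden_states: Any       # last-layer hidden states
    aux_hidden_states: Any              # selected aux-layer hidden states
    base_model_logits: Any = None       # optional: skips logit recompute

    @classmethod
    def from_offline_dict(cls, record: dict) -> "DFlashBaseModelOutput":
        # Required keys are indexed directly so a missing key raises at the
        # entry point instead of deep inside forward().
        return cls(
            base_model_hidden_states=record["base_model_hidden_states"],
            aux_hidden_states=record["aux_hidden_states"],
            base_model_logits=record.get("base_model_logits"),
        )
```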

Additional Information

Base branch is #1296 (file reorg). Retarget to main once #1296 merges.

@copy-pr-bot

copy-pr-bot Bot commented Apr 19, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@coderabbitai
Contributor

coderabbitai Bot commented Apr 19, 2026

Important

Review skipped

Auto reviews are disabled on base/target branches other than the default branch.

🗂️ Base branches to auto review (3)
  • main
  • release/.*
  • feature/.*

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: a8597fa1-a76e-4211-895c-8773afac8a96

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.


@github-actions
Contributor

github-actions Bot commented Apr 19, 2026

PR Preview Action v1.8.1

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1295/

Built to branch gh-pages at 2026-04-19 21:56 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov Bot commented Apr 19, 2026

Codecov Report

❌ Patch coverage is 88.05970% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 75.73%. Comparing base (536ea48) to head (e063406).

Files with missing lines                          Patch %   Lines
modelopt/torch/speculative/config.py              78.26%    5 Missing ⚠️
modelopt/torch/speculative/eagle/utils.py         60.00%    2 Missing ⚠️
modelopt/torch/speculative/plugins/hf_dflash.py   96.00%    1 Missing ⚠️
Additional details and impacted files
@@                     Coverage Diff                     @@
##           haoguo/spec-file-reorg    #1295       +/-   ##
===========================================================
+ Coverage                   60.34%   75.73%   +15.39%     
===========================================================
  Files                         470      471        +1     
  Lines                       50255    50375      +120     
===========================================================
+ Hits                        30325    38154     +7829     
+ Misses                      19930    12221     -7709     
Flag         Coverage Δ
examples     41.03% <37.31%> (+8.53%) ⬆️
gpu          58.40% <77.61%> (+42.63%) ⬆️
regression   14.84% <65.67%> (+0.04%) ⬆️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

@h-guo18 force-pushed the haoguo/dflash-offline branch 2 times, most recently from 9e4eeb0 to f208109 on April 19, 2026 21:53
@h-guo18 changed the base branch from main to haoguo/spec-file-reorg on April 19, 2026 21:54
@h-guo18 changed the title from "offline dflash" to "[2/3][Feat]: Offline DFlash training" on Apr 19, 2026
@h-guo18 self-assigned this on Apr 19, 2026
@h-guo18 force-pushed the haoguo/dflash-offline branch from f208109 to 178b191 on April 19, 2026 23:40
@h-guo18 marked this pull request as ready for review on April 22, 2026 02:38
@h-guo18 requested a review from a team as a code owner on April 22, 2026 02:38
@h-guo18 requested review from kevalmorabia97 and removed the request for a team on April 22, 2026 02:38
@h-guo18
Contributor Author

h-guo18 commented Apr 22, 2026

Self-review follow-ups (not in this PR — will address separately so this one stays focused):

  • hf_dflash.py offline forward: the assert "base_model_outputs" in kwargs gets stripped under python -O. Should become a raise RuntimeError(...) with a descriptive message.
  • hf_dflash.py modify(): self._base_model._modules.pop("layers") works but is non-idiomatic; del self._base_model.layers goes through PyTorch's proper submodule deregistration.
  • hf_dflash.py offline path reads base_config.num_orig_hidden_layers directly and would raise a bare AttributeError if unset. hf_eagle.py:172 handles the same concern with getattr(self.config, "num_orig_hidden_layers", 0); worth unifying.
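The first follow-up is mechanical; a minimal sketch of the intended guard is below. The error-message wording is illustrative, not the final text.

```python
# Sketch of replacing the -O-stripped assert with an explicit check.
def forward_offline(**kwargs):
    # Before: `assert "base_model_outputs" in kwargs`, which is silently
    # removed when Python runs with -O, leaving a confusing later failure.
    # After: an explicit check that survives optimized mode.
    if "base_model_outputs" not in kwargs:
        raise RuntimeError(
            "dflash_offline=True requires precomputed base_model_outputs "
            "in the batch; check the offline data collator."
        )
    return kwargs["base_model_outputs"]
```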

@h-guo18
Contributor Author

h-guo18 commented Apr 22, 2026

AI-assisted self-review (Claude) — findings posted for transparency.

Fixed in this PR:

  • _derive_dflash_offline: getattr guards offline_data_path; field marked auto-derived, not user-configurable.
  • dflash_offline=True now raises a clear error on non-training forward() and pseudo_speculative_generate() — base layers are deleted, so these paths would otherwise crash deep inside HF.
  • DFlashBaseModelOutput.from_offline_dict: aux_hidden_states required (direct indexing). Same for base_model_hidden_states on the logit-recompute path.
  • GPU test: dropped the include_logits=True branch in _make_base_model_outputs — production EagleOfflineDataCollator never emits base_model_logits, so the branch was exercising unreachable code.
  • CHANGELOG.rst updated; PR description now notes the silent add_generation_template → add_generation_prompt bug fix in compute_hidden_states_hf.py.

Deferred to follow-up (see earlier comment):

  • assert "base_model_outputs" in kwargs → raise RuntimeError(...) (stripped under python -O).
  • _base_model._modules.pop("layers") → del self._base_model.layers.
  • base_config.num_orig_hidden_layers direct access → getattr(..., 0) to match hf_eagle.py:172.
  • Plumb base_model_logits through EagleOfflineDataCollator so the logit-recompute branch becomes reachable (already in the TODO list).

Comment thread on CHANGELOG.rst
Collaborator

@ChenhanYu ChenhanYu left a comment


Review

CI test failure root cause

The test_unified_export_megatron failure (TypeError: '>' not supported between instances of 'NoneType' and 'int') is at megatron_eagle.py:526:

if self.config.parallel_draft_step > 1:

This is because PR #1296 (the [1/3] dependency) removed parallel_draft_step from eagle/default_config.py, but megatron_eagle.py still references it. The config attribute is now None, causing the comparison to fail. Fix: guard with getattr(self.config, "parallel_draft_step", 1) > 1 or restore the default in the config.
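A sketch of the suggested guard, with one refinement: getattr alone covers a removed attribute, but the traceback shows the attribute is present and set to None, so an `or 1` fallback is also needed. Config objects below are stand-ins.

```python
# Sketch of the guarded comparison suggested for megatron_eagle.py.
from types import SimpleNamespace


def parallel_draft_enabled(config) -> bool:
    # getattr handles an attribute removed from the default config;
    # `or 1` handles one that is present but None (the CI failure mode).
    return (getattr(config, "parallel_draft_step", 1) or 1) > 1
```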

Missing tests/gpu_regression for DFlash offline

The PR adds GPU integration and CPU unit tests that verify the forward pass runs and returns a finite loss — good. But there's no regression test that verifies end-to-end training convergence.

Please add a tests/gpu_regression test for DFlash offline that:

  1. Dumps hidden states from the same synthetic dataset used by the existing online DFlash regression test
  2. Trains offline DFlash for a few steps
  3. Verifies loss decreases (or matches a golden threshold)

Without this, the offline training path could silently regress while all existing tests still pass.
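The core assertion of such a regression test could be as small as a relative loss-drop check; dumping hidden states and running a few offline steps would reuse the existing online DFlash regression harness. The function name and the 5% threshold below are placeholders.

```python
# Placeholder shape of the requested convergence check.
def loss_decreased(losses, min_rel_drop=0.05):
    """True if the final loss dropped at least min_rel_drop vs. step 0."""
    return losses[-1] <= losses[0] * (1.0 - min_rel_drop)
```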

@h-guo18 force-pushed the haoguo/spec-file-reorg branch from 9c1ed15 to 536ea48 on April 23, 2026 21:36
@h-guo18 requested reviews from a team as code owners on April 23, 2026 21:36
@h-guo18 requested a review from ChenhanYu and removed the request for a team on April 23, 2026 21:36
h-guo18 added 3 commits April 23, 2026 21:43
- Add `dflash_offline` config flag for training from pre-computed hidden states;
  deletes base model layers to save memory.
- Move `dflash_mask_token_id` auto-detection from `main.py` into `DFlashConfig`
  Pydantic validators; derive `dflash_offline` from `data_args.offline_data_path`.
- Add `DFlashBaseModelOutput.from_offline_dict` classmethod for consuming
  pre-computed hidden states in the forward path.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
h-guo18 added 4 commits April 23, 2026 21:43
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@h-guo18 force-pushed the haoguo/dflash-offline branch from f7a27dc to e063406 on April 23, 2026 21:45
@h-guo18 merged commit e063406 into haoguo/spec-file-reorg on Apr 23, 2026
25 checks passed
@h-guo18 deleted the haoguo/dflash-offline branch on April 23, 2026 22:05
@h-guo18 restored the haoguo/dflash-offline branch on April 23, 2026 22:09
@h-guo18
Contributor Author

h-guo18 commented Apr 23, 2026

Accidentally closed. Reopened in #1337.

@h-guo18
Contributor Author

h-guo18 commented Apr 23, 2026

(Quoting ChenhanYu's review above: "Missing tests/gpu_regression for DFlash offline".)

This PR is reopened in #1337

Comment 1 addressed in the previous PR; comment 2 addressed in this PR. Thanks!
