[2/3][Feat]: Offline DFlash training #1295
Codecov Report

❌ Patch coverage is

@@ Coverage Diff @@
## haoguo/spec-file-reorg #1295 +/- ##
===========================================================
+ Coverage    60.34%   75.73%   +15.39%
===========================================================
  Files          470      471        +1
  Lines        50255    50375      +120
===========================================================
+ Hits         30325    38154     +7829
+ Misses       19930    12221     -7709

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Force-pushed 9e4eeb0 to f208109, then f208109 to 178b191.
Self-review follow-ups (not in this PR; will address separately so this one stays focused).

AI-assisted self-review (Claude): findings posted for transparency. Some findings are fixed in this PR; the rest are deferred to a follow-up (see earlier comment).
Review
CI test failure root cause
The `test_unified_export_megatron` failure (`TypeError: '>' not supported between instances of 'NoneType' and 'int'`) is at `megatron_eagle.py:526`:

```python
if self.config.parallel_draft_step > 1:
```

This is because PR #1296 (the [1/3] dependency) removed `parallel_draft_step` from `eagle/default_config.py`, but `megatron_eagle.py` still references it. The config attribute is now `None`, causing the comparison to fail. Fix: guard with `getattr(self.config, "parallel_draft_step", 1) > 1` or restore the default in the config.
Missing `tests/gpu_regression` test for DFlash offline
The PR adds GPU integration and CPU unit tests that verify the forward pass runs and returns a finite loss, which is good. But there's no regression test that verifies end-to-end training convergence.
Please add a `tests/gpu_regression` test for DFlash offline that:
- Dumps hidden states from the same synthetic dataset used by the existing online DFlash regression test
- Trains offline DFlash for a few steps
- Verifies loss decreases (or matches a golden threshold)
Without this, the offline training path could silently regress while all existing tests still pass.
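The convergence check requested above could boil down to a small assertion helper. This is a hypothetical sketch: the name `assert_converged` and the thresholds are assumptions, not an existing test utility:

```python
# Hypothetical helper for the requested regression test: fail if a short
# offline DFlash training run does not reduce the loss enough.
def assert_converged(losses, min_rel_drop=0.2, golden_final=None):
    """losses: per-step training losses; golden_final: optional golden threshold."""
    assert len(losses) >= 2, "need at least two steps to measure a drop"
    first, last = losses[0], losses[-1]
    rel_drop = (first - last) / first
    assert rel_drop >= min_rel_drop, f"loss dropped only {rel_drop:.1%}"
    if golden_final is not None:
        assert last <= golden_final, f"final loss {last} above golden {golden_final}"

# A run whose loss falls 2.0 -> 1.2 clears the default 20% drop requirement.
assert_converged([2.0, 1.6, 1.3, 1.2])
```

The regression test itself would dump hidden states from the synthetic dataset, train offline DFlash for a few steps, collect the per-step losses, and call this helper.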
Force-pushed 9c1ed15 to 536ea48.
- Add `dflash_offline` config flag for training from pre-computed hidden states; deletes base model layers to save memory.
- Move `dflash_mask_token_id` auto-detection from `main.py` into `DFlashConfig` Pydantic validators; derive `dflash_offline` from `data_args.offline_data_path`.
- Add `DFlashBaseModelOutput.from_offline_dict` classmethod for consuming pre-computed hidden states in the forward path.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
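The `from_offline_dict` entry point named in this commit can be sketched as follows. The field names come from the PR description, but the exact class shape here is an assumption, not the actual `modeling_dflash.py` code:

```python
# Sketch of the fail-fast pattern: require the mandatory keys at the entry
# point so a bad offline dump errors immediately, not deep inside forward().
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class DFlashBaseModelOutput:
    base_model_hidden_states: Any
    aux_hidden_states: Any
    base_model_logits: Optional[Any] = None  # absent: recompute via lm_head

    @classmethod
    def from_offline_dict(cls, d: dict) -> "DFlashBaseModelOutput":
        missing = {"base_model_hidden_states", "aux_hidden_states"} - d.keys()
        if missing:
            raise KeyError(f"offline dump missing required keys: {sorted(missing)}")
        return cls(
            base_model_hidden_states=d["base_model_hidden_states"],
            aux_hidden_states=d["aux_hidden_states"],
            base_model_logits=d.get("base_model_logits"),
        )
```

Making `base_model_logits` optional mirrors the described self-distillation path: when it is `None`, logits are recomputed from `base_model_hidden_states` via the retained lm_head.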
Force-pushed f7a27dc to e063406.
Accidentally closed. Reopened in #1337.

This PR is reopened in #1337. Comment 1 was addressed in the previous PR; comment 2 is addressed in this PR. Thanks!
What does this PR do?
Type of change: new feature
Part 2 of a 3-PR series splitting #1271:
`ParallelDraftHFSpecDecMixin`

Changes:
- Add `dflash_offline` flag to `DFlashConfig` for training from pre-computed hidden states; deletes base model layers to save memory.
- `DFlashConfig` validators:
  - `_derive_dflash_offline`: auto-derive `dflash_offline` from `data_args.offline_data_path` in the validation context. Not user-configurable: any user-supplied value is overridden by the derived value.
  - `_resolve_mask_token_id`: auto-detect `dflash_mask_token_id` from `tokenizer.mask_token_id`.
  - `_check_mask_token_id`: fail fast if unset after resolution.
- `HFDFlashModel.modify()`: select `num_orig_hidden_layers` when offline; pick the `_base_model_lm_head` device when no base layers are present; drop the base-model `layers` module.
- `HFDFlashModel.forward()`: add an offline branch that consumes precomputed `base_model_outputs` via `DFlashBaseModelOutput.from_offline_dict` and, when `dflash_self_logit_distillation` is enabled with `base_model_logits` absent, recomputes logits from `base_model_hidden_states` via `_base_model_lm_head`. Raises a clear error from the non-training / `pseudo_speculative_generate` paths when `dflash_offline=True`, since the base-model layers have been deleted.
- Add a `DFlashBaseModelOutput` dataclass in `modeling_dflash.py` (with a `from_offline_dict` classmethod) to unify online/offline output shapes. `aux_hidden_states` is required in `from_offline_dict` so missing keys fail fast at the entry point rather than deeper in the forward.
- `examples/speculative_decoding/main.py`: replace the inline `mask_token_id` auto-detect with `DFlashConfig.model_validate(dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args})`.

Silent bug fix:
`add_generation_template` → `add_generation_prompt`

The pre-refactor `compute_hidden_states_hf.py` passed `add_generation_template=False` to `tokenizer.apply_chat_template`. This kwarg does not exist on HF `apply_chat_template` and was being silently ignored, so the intended "don't append a generation prompt" behavior was never actually applied. The new `tokenize_with_loss_mask` helper in `examples/speculative_decoding/collect_hidden_states/common.py` uses the correct `add_generation_prompt=False`. This is a real behavior change for anyone re-dumping hidden states: trailing generation prompts that were previously appended to the tokenized sequences will no longer be included.

Testing
New tests:
- `tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py`: CPU unit tests for the convert path (online keeps base layers, offline deletes them; `num_orig_hidden_layers` drives `target_layer_ids` in offline mode) and for the `DFlashConfig._derive_dflash_offline` validator.
- `TestDFlashOfflineForwardGPU` in `tests/gpu/torch/speculative/plugins/test_hf_dflash.py`: GPU forward smoke test with precomputed `base_model_outputs`, plus the `dflash_self_logit_distillation` logit-recompute path.

Training test:

Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- Backward compatibility: the new `dflash_offline` flag defaults to `False`; validators fall through when context is not provided.
- CONTRIBUTING.md: N/A

TODO (follow-up)
Update `examples/speculative_decoding/collect_hidden_states/compute_hidden_states_*.py` to support DFlash offline data. The current scripts are Eagle-specific: they hardcode the `[2, N/2, N-3]` aux-layer selection and emit `{input_ids, hidden_states, aux_hidden_states}`. DFlash offline needs:
- `build_target_layer_ids(num_orig_hidden_layers, num_draft_layers)` (or a configurable list), not the Eagle triplet.
- A `base_model_hidden_states` key (last-layer hidden) so `DFlashBaseModelOutput.from_offline_dict` and the `dflash_self_logit_distillation` recompute path can consume it.
- An optional `base_model_logits` dump so offline training can skip the self-distillation logit recomputation when logits are available.

Additional Information
Base branch is #1296 (file reorg). Retarget to `main` once #1296 merges.
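On the `build_target_layer_ids(num_orig_hidden_layers, num_draft_layers)` helper listed in the TODO above: a hypothetical sketch, shown next to the Eagle triplet for contrast. The even-spacing policy here is purely an assumption — the project may select layers differently:

```python
# Hypothetical sketches; only the function names/signatures come from the PR.
def eagle_aux_layer_ids(num_orig_hidden_layers: int) -> list[int]:
    """The hardcoded Eagle selection described above: [2, N/2, N-3]."""
    n = num_orig_hidden_layers
    return [2, n // 2, n - 3]

def build_target_layer_ids(num_orig_hidden_layers: int, num_draft_layers: int) -> list[int]:
    """Assumed policy: num_draft_layers ids spread evenly, ending at the last layer."""
    n, k = num_orig_hidden_layers, num_draft_layers
    if k <= 0 or k > n:
        raise ValueError(f"need 0 < num_draft_layers <= {n}, got {k}")
    return [round((i + 1) * n / k) - 1 for i in range(k)]

assert eagle_aux_layer_ids(32) == [2, 16, 29]
assert build_target_layer_ids(32, 4) == [7, 15, 23, 31]
```

Whatever policy is chosen, keeping it behind one helper (rather than the inline triplet) is what lets the dump scripts serve both Eagle and DFlash offline.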