
fix incomplete mapping of safetensors in generated puzzletron checkpoint#1330

Open
grzegorz-k-karch wants to merge 6 commits intomainfrom
gkarch/fix-incomplete-tensor-mapping
Conversation

@grzegorz-k-karch
Contributor

@grzegorz-k-karch grzegorz-k-karch commented Apr 23, 2026

What does this PR do?

Type of change: Bug fix

Fixes https://github.com/NVIDIA/Model-Optimizer/blob/main/examples/puzzletron/main.py, where a multi-GPU run caused only part of the model.safetensors.index.json file to be written to disk.

Usage

Does not apply.

Testing

Follow the instructions in step 3 and run with --nproc_per_node 2.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: N/A
  • Did you update Changelog?: N/A

Additional Information

Summary by CodeRabbit

  • New Features

  • Added a public checkpoint-saving entry point that consolidates distributed sharded model shards into a single filesystem checkpoint; retains direct saving for single-process runs.
  • Refactor

    • Validation/evaluation tooling now uses the consolidated checkpoint-saving flow when persisting realized model checkpoints during runs.

Signed-off-by: Grzegorz Karch <gkarch@nvidia.com>
@grzegorz-k-karch grzegorz-k-karch requested a review from a team as a code owner April 23, 2026 09:23
@copy-pr-bot

copy-pr-bot Bot commented Apr 23, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@grzegorz-k-karch grzegorz-k-karch marked this pull request as draft April 23, 2026 09:23
@coderabbitai
Contributor

coderabbitai Bot commented Apr 23, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a distributed-aware checkpoint saver save_checkpoint_from_shards that gathers CPU-backed per-rank state_dict shards to rank 0, merges them, and delegates disk writes to the existing saver. Updates the validation script to call this new function when saving realized model checkpoints.
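The gather-and-merge step on rank 0 can be sketched in plain Python. This is an illustrative helper only, not the PR's actual code; the real function uses torch.distributed and lives in checkpoint_utils_hf.py:

```python
def merge_shards(gathered):
    """Merge per-rank state_dict shards into one full state_dict.

    `gathered` stands in for the list that tdist.gather_object fills on
    rank 0; entries may be None for ranks that contributed nothing.
    """
    full_sd = {}
    for shard_sd in gathered:
        if shard_sd is None:
            continue  # skip empty placeholder slots
        full_sd.update(shard_sd)
    return full_sd

# Two ranks each hold a disjoint subset of the parameters.
shards = [{"embed.weight": "t0", "layers.0.weight": "t1"},
          {"layers.1.weight": "t2", "lm_head.weight": "t3"}]
full = merge_shards(shards)
print(sorted(full))  # all four parameter names end up in one dict
```

With every parameter present in the single merged dict, the existing _save_checkpoint can emit a complete model.safetensors.index.json instead of one covering only rank 0's shard.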

Changes

  • Distributed checkpoint util: modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
    Adds exported save_checkpoint_from_shards(model: PreTrainedModel, checkpoint_dir: Path | str, descriptor: "ModelDescriptor") -> None. Converts per-rank model.state_dict() tensors to CPU, uses torch.distributed / modelopt.torch.utils.distributed to detect the world size and tdist.gather_object to collect shards on rank 0, merges the shards into a full state_dict, and calls the existing _save_checkpoint on rank 0. Adds the function to __all__.
  • Validation script: modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py
    Replaces the import and call of save_checkpoint with save_checkpoint_from_shards(...) inside validate_puzzle_solutions when args.save_models is enabled; tokenizer copying logic is unchanged.

Sequence Diagram(s)

sequenceDiagram
    participant Worker as "Worker Rank i"
    participant TDist as "torch.distributed"
    participant Rank0 as "Rank 0 (Collector)"
    participant Saver as "_save_checkpoint (Disk)"
    Worker->>TDist: prepare CPU state_dict shard\ntdist.gather_object(shard)
    TDist->>Rank0: deliver gathered shard objects
    Rank0->>Rank0: merge shards -> full state_dict
    Rank0->>Saver: call _save_checkpoint(full state_dict, checkpoint_dir, descriptor)
    Saver->>Saver: write safetensors index and subblocks to disk

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 50.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (5 passed)

  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check ✅ Passed: The title describes fixing incomplete safetensors mapping, which aligns with the PR objective of addressing incomplete model.safetensors.index.json writes in multi-GPU scenarios.
  • Linked Issues check ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Security Anti-Patterns ✅ Passed: Pull request introduces checkpoint-saving functionality without any of the six security anti-patterns specified in SECURITY.md.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Contributor

github-actions Bot commented Apr 23, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1330/

Built to branch gh-pages at 2026-04-23 23:24 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Contributor

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py (1)

175-177: ⚠️ Potential issue | 🔴 Critical

model can be undefined before checkpoint save.

When args.save_models=True, args.skip_validation=True, and realizable_as_symlinks=True, model is never initialized but is still used at Line 192, causing a runtime failure.

💡 Proposed fix
-        if (args.save_models and not realizable_as_symlinks) or (not args.skip_validation):
+        if args.save_models or (not args.skip_validation):
             model = replacement_library.load_model(layer_replacements)
             model_config = model.config

Also applies to: 192-192

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py`
around lines 175 - 177, The code fails to initialize model when args.save_models
is True but realizable_as_symlinks is True and args.skip_validation is True;
change the load condition so replacement_library.load_model(layer_replacements)
(and model_config = model.config) runs whenever you will save checkpoints or
will validate — i.e., use if args.save_models or (not args.skip_validation): to
ensure model is defined before it's later used (references: args.save_models,
args.skip_validation, realizable_as_symlinks, model,
replacement_library.load_model, layer_replacements, model_config).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: f462a7b9-d95d-49ea-9e4b-2c3e673190a3

📥 Commits

Reviewing files that changed from the base of the PR and between 2564da7 and f74ab04.

📒 Files selected for processing (2)
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
  • modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py

@codecov

codecov Bot commented Apr 23, 2026

Codecov Report

❌ Patch coverage is 21.05263% with 15 lines in your changes missing coverage. Please review.
✅ Project coverage is 59.19%. Comparing base (5887410) to head (f10fc53).

Files with missing lines Patch % Lines
...lopt/torch/puzzletron/tools/checkpoint_utils_hf.py 17.64% 14 Missing ⚠️
...n/tools/validate_puzzle_with_multi_replacements.py 50.00% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1330       +/-   ##
===========================================
- Coverage   74.67%   59.19%   -15.49%     
===========================================
  Files         468      479       +11     
  Lines       50369    51297      +928     
===========================================
- Hits        37615    30366     -7249     
- Misses      12754    20931     +8177     
Flag Coverage Δ
unit 52.69% <21.05%> (-0.03%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

Signed-off-by: Grzegorz Karch <gkarch@nvidia.com>
@grzegorz-k-karch grzegorz-k-karch marked this pull request as ready for review April 23, 2026 15:55
@coderabbitai
Contributor

coderabbitai Bot commented Apr 23, 2026

Note

Docstrings generation - SUCCESS
Generated docstrings for this pull request at #1331

coderabbitai Bot added a commit that referenced this pull request Apr 23, 2026
Docstrings generation was requested by @grzegorz-k-karch.

* #1330 (comment)

The following files were modified:

* `modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py`
* `modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py`

ℹ️ Note: CodeRabbit cannot perform edits on its own pull requests yet.

---------

Signed-off-by: Grzegorz Karch <gkarch@nvidia.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Grzegorz Karch <gkarch@nvidia.com>
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py`:
- Around line 226-237: The master rank writes the checkpoint after
tdist.gather_object while non-master ranks return immediately, so add an
explicit torch.distributed.barrier() (or dist_utils.barrier() if provided)
immediately after the _save_checkpoint(...) call to ensure all ranks wait until
the file I/O completes; also initialize gathered as [None] * dist_utils.size()
instead of [{}] * dist_utils.size() and, when merging gathered into full_sd,
skip any None entries (e.g., if shard_sd is None continue) before calling
full_sd.update(shard_sd) to avoid shared-reference/None issues.
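The [{}] * size pitfall flagged above is worth seeing concretely: list multiplication repeats a reference to one dict rather than creating independent dicts, while [None] * size carries no mutable state that can be aliased. A generic Python demonstration, not code from the PR:

```python
# Every slot of `shared` is the SAME dict object, so a mutation through
# one slot is visible through all of them.
shared = [{}] * 3
shared[0]["weight"] = 1.0
assert shared[1] == {"weight": 1.0}  # mutated via slot 0

# [None] placeholders cannot alias anything; gather_object overwrites
# each slot wholesale, and the merge loop skips any slot left as None.
safe = [None] * 3
safe[0] = {"weight": 1.0}
assert safe[1] is None
```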
🪄 Autofix (Beta)

✅ Autofix completed


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 8296bba7-370d-4543-8234-9e64ba16d6c6

📥 Commits

Reviewing files that changed from the base of the PR and between 1e8e25a and 84d68a3.

📒 Files selected for processing (1)
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py

Comment thread modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
@kevalmorabia97
Collaborator

/ok to test 84d68a3

@coderabbitai
Contributor

coderabbitai Bot commented Apr 23, 2026

Note

Autofix is a beta feature. Expect some limitations and changes as we gather feedback and continue to improve it.

Fixes Applied Successfully

Fixed 1 file based on 1 unresolved review comment.

A stacked PR containing fixes has been created.

  • Stacked PR: #1339
  • Files modified:
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py

Time taken: 1m 59s

This stacked PR contains CodeRabbit auto-fixes for #1330.

**Files modified:**
- `modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py`

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: CodeRabbit <noreply@coderabbit.ai>
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py (1)

231-234: Guard against silent key collisions during shard merge.

full_sd.update(shard_sd) will silently overwrite if two ranks emit the same key. Add an overlap check to fail fast.

Suggested hardening
             for shard_sd in gathered:
                 if shard_sd is None:
                     continue
+                overlap = full_sd.keys() & shard_sd.keys()
+                if overlap:
+                    raise ValueError(f"Duplicate parameter keys across shards: {sorted(overlap)[:8]}")
                 full_sd.update(shard_sd)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py` around lines 231 -
234, When merging shard dictionaries into full_sd (loop over gathered /
shard_sd), guard against silent key collisions by checking for overlapping keys
before calling full_sd.update(shard_sd); compute the intersection between
full_sd.keys() and shard_sd.keys(), and if non-empty raise an informative
exception (including the overlapping key names and the source shard identifier
if available) so the merge fails fast instead of silently overwriting existing
entries.
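The fail-fast check suggested above relies on dict key views being set-like, so & yields the overlapping keys directly. A standalone sketch with made-up parameter names (not the PR's actual code):

```python
def check_no_overlap(full_sd, shard_sd):
    """Raise if merging shard_sd would silently overwrite keys in full_sd."""
    overlap = full_sd.keys() & shard_sd.keys()  # set intersection of key views
    if overlap:
        raise ValueError(f"Duplicate parameter keys across shards: {sorted(overlap)[:8]}")

full_sd = {"layers.0.weight": 0, "layers.0.bias": 1}
check_no_overlap(full_sd, {"layers.1.weight": 2})  # disjoint shards pass silently
```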
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py`:
- Around line 227-240: When rank 0 calls _save_checkpoint it must catch
exceptions and propagate a status to all ranks so other ranks aren't left
blocked in tdist.gather_object; wrap the _save_checkpoint call in try/except,
set a boolean or dict status (e.g., save_status = {"ok": True} or {"ok": False,
"err": str(e)}), then use tdist.gather_object to collect statuses from all ranks
(same collective used for local_sd gathering) before calling
dist_utils.barrier(); after gathering, check the collected statuses on every
rank and raise or exit on failure (so every rank fails cleanly). Ensure you
reference the existing locals (local_sd, gathered, tdist.gather_object,
_save_checkpoint, dist_utils.barrier, dist_utils.is_master) when implementing
this propagation.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 2b1017c8-adcc-4c73-a554-4a8161b0f151

📥 Commits

Reviewing files that changed from the base of the PR and between 84d68a3 and ff2afe1.

📒 Files selected for processing (1)
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py

Comment on lines +227 to +240
        if dist_utils.is_master():
            gathered: list[dict] = [None] * dist_utils.size()
            tdist.gather_object(local_sd, gathered, dst=0)
            full_sd: dict[str, torch.Tensor] = {}
            for shard_sd in gathered:
                if shard_sd is None:
                    continue
                full_sd.update(shard_sd)
            _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor)
        else:
            tdist.gather_object(local_sd, dst=0)
        # Barrier ensures all ranks wait until file I/O completes before continuing
        dist_utils.barrier()
    else:
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find . -name "checkpoint_utils_hf.py" -type f

Repository: NVIDIA/Model-Optimizer

Length of output: 122


🏁 Script executed:

fd "checkpoint_utils_hf.py"

Repository: NVIDIA/Model-Optimizer

Length of output: 120


🏁 Script executed:

find . -type f -name "*checkpoint*" | grep -i hf

Repository: NVIDIA/Model-Optimizer

Length of output: 295


🏁 Script executed:

ls -la modelopt/torch/puzzletron/tools/ 2>/dev/null || ls -la modelopt/ 2>/dev/null | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 1094


🏁 Script executed:

find . -path "*/puzzletron/tools/*.py" 2>/dev/null | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 912


🏁 Script executed:

git ls-files | grep -i puzzletron | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 1640


🏁 Script executed:

git ls-files | grep -E "modelopt.*\.py$" | head -30

Repository: NVIDIA/Model-Optimizer

Length of output: 1226


🏁 Script executed:

ls -la modelopt/ 2>/dev/null || echo "modelopt dir not found"

Repository: NVIDIA/Model-Optimizer

Length of output: 458


🏁 Script executed:

pwd && ls -la | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 1310


🏁 Script executed:

rg "_save_checkpoint" -A 10 -B 10 2>/dev/null || echo "Pattern not found"

Repository: NVIDIA/Model-Optimizer

Length of output: 20909


🏁 Script executed:

sed -n '220,245p' modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 1323


🏁 Script executed:

sed -n '210,250p' modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 2192


🏁 Script executed:

wc -l modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py && sed -n '195,250p' modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2541


🏁 Script executed:

sed -n '244,350p' modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py

Repository: NVIDIA/Model-Optimizer

Length of output: 4550


🏁 Script executed:

sed -n '244,400p' modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py

Repository: NVIDIA/Model-Optimizer

Length of output: 6569


🏁 Script executed:

fd "dist_utils.py" -o -type f | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 182


🏁 Script executed:

find . -name "dist_utils.py" -type f

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

head -30 modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1124


🏁 Script executed:

grep -n "import.*dist_utils\|from.*dist_utils" modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py

Repository: NVIDIA/Model-Optimizer

Length of output: 122


🏁 Script executed:

rg "def barrier\(\)" modelopt/torch/puzzletron/ -A 5

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

find modelopt -name "dist_utils*" -o -name "*distributed*" | grep -v __pycache__

Repository: NVIDIA/Model-Optimizer

Length of output: 101


🏁 Script executed:

find modelopt/torch -type d -name "*dist*"

Repository: NVIDIA/Model-Optimizer

Length of output: 88


🏁 Script executed:

grep "dist_utils" modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 275


🏁 Script executed:

grep -B 20 "def save_checkpoint_from_shards" modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py | grep "^from\|^import"

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

head -50 modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py | grep -E "^from|^import"

Repository: NVIDIA/Model-Optimizer

Length of output: 826


🏁 Script executed:

sed -n '1,80p' modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2628


🏁 Script executed:

find modelopt -name "dist_utils.py" 2>/dev/null

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

grep -r "from.*import.*dist_utils" modelopt/torch/puzzletron/tools/ 2>/dev/null | head -3

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

python3 << 'EOF'
with open('modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py', 'r') as f:
    lines = f.readlines()
    # Find where dist_utils is used
    for i, line in enumerate(lines[200:250], start=201):
        if 'dist_utils' in line or 'barrier' in line or 'gather_object' in line:
            print(f"{i}: {line.rstrip()}")
EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 359


🏁 Script executed:

python3 << 'EOF'
with open('modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py', 'r') as f:
    lines = f.readlines()
    for i in range(223, 241):
        print(f"{i+1}: {lines[i].rstrip()}")
EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 947


Prevent worker hang when rank-0 save fails.

If _save_checkpoint throws on rank 0 (line 235), non-master ranks remain blocked at gather_object (line 237) waiting for the collective operation to complete, causing a distributed deadlock. Propagate save status across all ranks before the barrier to ensure clean failure.

Suggested fix
     if dist_utils.size() > 1:
+        save_err: str | None = None
         if dist_utils.is_master():
             gathered: list[dict] = [None] * dist_utils.size()
             tdist.gather_object(local_sd, gathered, dst=0)
             full_sd: dict[str, torch.Tensor] = {}
             for shard_sd in gathered:
                 if shard_sd is None:
                     continue
                 full_sd.update(shard_sd)
-            _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor)
+            try:
+                _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor)
+            except Exception as e:
+                save_err = repr(e)
         else:
             tdist.gather_object(local_sd, dst=0)
-        # Barrier ensures all ranks wait until file I/O completes before continuing
+        err_box = [save_err]
+        tdist.broadcast_object_list(err_box, src=0)
         dist_utils.barrier()
+        if err_box[0] is not None:
+            raise RuntimeError(f"Checkpoint save failed on rank 0: {err_box[0]}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

Original:

        if dist_utils.is_master():
            gathered: list[dict] = [None] * dist_utils.size()
            tdist.gather_object(local_sd, gathered, dst=0)
            full_sd: dict[str, torch.Tensor] = {}
            for shard_sd in gathered:
                if shard_sd is None:
                    continue
                full_sd.update(shard_sd)
            _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor)
        else:
            tdist.gather_object(local_sd, dst=0)
        # Barrier ensures all ranks wait until file I/O completes before continuing
        dist_utils.barrier()
    else:

Replacement:

    if dist_utils.size() > 1:
        save_err: str | None = None
        if dist_utils.is_master():
            gathered: list[dict] = [None] * dist_utils.size()
            tdist.gather_object(local_sd, gathered, dst=0)
            full_sd: dict[str, torch.Tensor] = {}
            for shard_sd in gathered:
                if shard_sd is None:
                    continue
                full_sd.update(shard_sd)
            try:
                _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor)
            except Exception as e:
                save_err = repr(e)
        else:
            tdist.gather_object(local_sd, dst=0)
        err_box = [save_err]
        tdist.broadcast_object_list(err_box, src=0)
        # Barrier ensures all ranks wait until file I/O completes before continuing
        dist_utils.barrier()
        if err_box[0] is not None:
            raise RuntimeError(f"Checkpoint save failed on rank 0: {err_box[0]}")
    else:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py` around lines 227 -
240, When rank 0 calls _save_checkpoint it must catch exceptions and propagate a
status to all ranks so other ranks aren't left blocked in tdist.gather_object;
wrap the _save_checkpoint call in try/except, set a boolean or dict status
(e.g., save_status = {"ok": True} or {"ok": False, "err": str(e)}), then use
tdist.gather_object to collect statuses from all ranks (same collective used for
local_sd gathering) before calling dist_utils.barrier(); after gathering, check
the collected statuses on every rank and raise or exit on failure (so every rank
fails cleanly). Ensure you reference the existing locals (local_sd, gathered,
tdist.gather_object, _save_checkpoint, dist_utils.barrier, dist_utils.is_master)
when implementing this propagation.
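The capture-then-propagate pattern the reviewers describe can be modelled without torch.distributed. In this single-process sketch the broadcast step is collapsed into a local variable; in the real fix the status travels via tdist.broadcast_object_list before dist_utils.barrier() so every rank raises on the same outcome. Names here are illustrative only:

```python
def save_with_propagated_status(is_master, save_fn):
    """Run the rank-0 save, but turn any exception into a status value
    instead of letting it skip the synchronization point."""
    save_err = None
    if is_master:
        try:
            save_fn()  # stands in for _save_checkpoint(...)
        except Exception as exc:
            save_err = repr(exc)
    # Real code: tdist.broadcast_object_list([save_err], src=0), then
    # dist_utils.barrier(); afterwards every rank checks the shared status.
    if save_err is not None:
        raise RuntimeError(f"Checkpoint save failed on rank 0: {save_err}")

save_with_propagated_status(True, lambda: None)  # successful save: no exception
```

The key property is that rank 0 always reaches the collective after the save attempt, so no rank is left blocked when disk I/O fails.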

Signed-off-by: Grzegorz Karch <gkarch@nvidia.com>
Contributor

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py (1)

227-240: ⚠️ Potential issue | 🔴 Critical

Handle rank-0 save exceptions before synchronization to prevent deadlock.

If _save_checkpoint(...) fails on Line 235, rank 0 never reaches Line 239, while non-master ranks can wait forever in dist_utils.barrier(). Broadcast rank-0 save status before the barrier and fail all ranks consistently.

Suggested fix
 def save_checkpoint_from_shards(
     model: PreTrainedModel, checkpoint_dir: Path | str, descriptor: "ModelDescriptor"
 ) -> None:
@@
     local_sd = {k: v.cpu() for k, v in model.state_dict().items()}
     if dist_utils.size() > 1:
+        save_error: str | None = None
         if dist_utils.is_master():
-            gathered: list[dict] = [None] * dist_utils.size()
+            gathered: list[dict[str, torch.Tensor] | None] = [None] * dist_utils.size()
             tdist.gather_object(local_sd, gathered, dst=0)
             full_sd: dict[str, torch.Tensor] = {}
             for shard_sd in gathered:
                 if shard_sd is None:
                     continue
                 full_sd.update(shard_sd)
-            _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor)
+            try:
+                _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor)
+            except Exception as exc:
+                save_error = repr(exc)
         else:
             tdist.gather_object(local_sd, dst=0)
+        err_box = [save_error]
+        tdist.broadcast_object_list(err_box, src=0)
         # Barrier ensures all ranks wait until file I/O completes before continuing
         dist_utils.barrier()
+        if err_box[0] is not None:
+            raise RuntimeError(f"Checkpoint save failed on rank 0: {err_box[0]}")
     else:
         _save_checkpoint(model.config, local_sd, checkpoint_dir, descriptor)
#!/bin/bash
# Verify whether distributed save path propagates rank-0 failures before barrier.
rg -n -C4 '\bdef save_checkpoint_from_shards\b|gather_object\(|_save_checkpoint\(|broadcast_object_list\(|barrier\(' modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py` around lines 227 -
240, Wrap the rank-0 call to _save_checkpoint(model.config, full_sd,
checkpoint_dir, descriptor) in a try/except that captures any exception and sets
a serializable status object (e.g., {"ok": False, "err": str(exc)}) which you
then broadcast to all ranks before calling dist_utils.barrier(); on success set
status {"ok": True}. Use the existing tdist.gather_object/gather flow (inside
dist_utils.is_master branch) to build full_sd and then broadcast the status
(e.g., via dist_utils.broadcast_object_list or tdist.broadcast_object_list) so
non-master ranks receive the save outcome before entering dist_utils.barrier();
non-master ranks should check the received status and raise/exit if {"ok":
False} to ensure consistent failure across ranks. Ensure you reference
_save_checkpoint, dist_utils.is_master, tdist.gather_object, and
dist_utils.barrier in your changes and keep the status object simple and
JSON-serializable.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 6e3f52c6-47dd-4026-82ce-a2331512edaa

📥 Commits

Reviewing files that changed from the base of the PR and between ff2afe1 and f10fc53.

📒 Files selected for processing (1)
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py

