This skill helps you integrate PyTorch FSDP2 into training scripts with correct initialization, sharding, mixed precision, and DTensor-based checkpointing.
---
name: pytorch-fsdp2
description: Adds PyTorch FSDP2 (fully_shard) to training scripts with correct init, sharding, mixed precision/offload config, and distributed checkpointing. Use when models exceed single-GPU memory or when you need DTensor-based sharding with DeviceMesh.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [PyTorch, FSDP2, Fully Sharded Data Parallel, Distributed Training, DTensor, Device Mesh, Sharded Checkpointing, Mixed Precision, Offload, Torch Distributed]
dependencies: [torch]
---
# Skill: Use PyTorch FSDP2 (`fully_shard`) correctly in a training script
This skill teaches a coding agent how to **add PyTorch FSDP2** to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing.
> FSDP2 in PyTorch is exposed primarily via `torch.distributed.fsdp.fully_shard` and the `FSDPModule` methods it adds in-place to modules. See: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.
---
## When to use this skill
Use FSDP2 when:
- Your model **doesn’t fit** on one GPU (parameters + gradients + optimizer state).
- You want an eager-mode approach built on **DTensor-based per-parameter sharding**, which is more inspectable and yields simpler sharded state dicts than FSDP1.
- You may later compose DP with **Tensor Parallel** using **DeviceMesh**.
Avoid (or be careful) if:
- You need strict backwards-compatible checkpoints across PyTorch versions (the DCP docs warn that saved state dicts carry no such guarantee).
- You’re forced onto older PyTorch versions without the FSDP2 stack.
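As a rough illustration of the "doesn't fit" bullet, the sketch below estimates per-GPU memory for parameters, gradients, and optimizer state. It assumes fp32 parameters and gradients (4 bytes each) and an Adam-style optimizer holding two fp32 states per parameter (8 bytes). Activations, buffers, and fragmentation are ignored, and `model_state_gib` is a hypothetical helper name.

```python
def model_state_gib(num_params: int, world_size: int = 1,
                    bytes_per_param: int = 4, bytes_per_grad: int = 4,
                    bytes_optim_state: int = 8) -> float:
    """Approximate GiB per GPU for params + grads + optimizer state.

    Assumes the state is split evenly across `world_size` ranks.
    """
    per_param = bytes_per_param + bytes_per_grad + bytes_optim_state
    return num_params * per_param / world_size / 2**30


# A 7B-parameter model, unsharded versus sharded over 8 ranks.
print(f"{model_state_gib(7_000_000_000):.1f} GiB unsharded")
print(f"{model_state_gib(7_000_000_000, world_size=8):.1f} GiB per GPU on 8 ranks")
```

Under these assumptions a 7B-parameter model needs roughly 104 GiB of model state on a single GPU, and about 13 GiB per GPU when sharded evenly over 8 ranks.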
## Alternatives (when FSDP2 is not the best fit)
- **DistributedDataParallel (DDP)**: Use when the full model, gradients, and optimizer state fit on each GPU; DDP replicates the model on every rank and all-reduces gradients.
- **FullyShardedDataParallel (FSDP1)**: The original wrapper-class FSDP; use it when an existing codebase depends on that API and you still need parameter sharding across data-parallel workers.
Reference: `references/pytorch_ddp_notes.md`, `references/pytorch_fsdp1_api.md`.
---
## Contract the agent must follow
1. **Launch with `torchrun`** and set the CUDA device per process (usually via `LOCAL_RANK`).
2. **Apply `fully_shard()` bottom-up**, i.e., shard submodules (e.g., Transformer blocks) before the root module.
3. **Call `model(input)`**, not `model.forward(input)`, so the FSDP2 hooks run (unless you explicitly call `unshard()` first or register the method with `register_fsdp_forward_method`).
4. **Create the optimizer after sharding** and make sure it is built on the **DTensor parameters** (post-`fully_shard`).
5. **Checkpoint using Distributed Checkpoint (DCP)** or the distributed-state-dict helpers, not naïve `torch.save(model.state_dict())` unless you deliberately gather to full tensors.
(Each of these rules is directly described in the official API docs/tutorial; see references.)
---
## Step-by-step procedure
### 0) Version & environment sanity
- Prefer a recent stable PyTorch release in which `fully_shard` is importable from `torch.distributed.fsdp` and the DCP docs are current.
- Use `torchrun --nproc_per_node <gpus_per_node> ...` and ensure `RANK`, `WORLD_SIZE`, `LOCAL_RANK` are visible.
Reference: `references/pytorch_fsdp2_tutorial.md` (launch commands and setup), `references/pytorch_fully_shard_api.md` (user contract).
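A minimal sketch of the environment check described above, assuming the script is launched by `torchrun`. The helper name `read_torchrun_env` and the error message are illustrative, not part of any PyTorch API.

```python
import os

REQUIRED_VARS = ("RANK", "WORLD_SIZE", "LOCAL_RANK")


def read_torchrun_env(env=os.environ) -> dict:
    """Return rank info as ints, failing early if torchrun did not set it."""
    missing = [name for name in REQUIRED_VARS if name not in env]
    if missing:
        raise RuntimeError(
            f"Missing {missing}; launch with "
            "`torchrun --nproc_per_node <gpus_per_node> ...`"
        )
    return {name.lower(): int(env[name]) for name in REQUIRED_VARS}
```

Calling this at the top of the script turns a confusing downstream failure into an immediate, readable error when the script is started with plain `python`.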
---
### 1) Initialize distributed and set device
Minimal, correct pattern:
- `dist.init_process_group(backend="nccl")`
- `torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))`
- Optionally create a `DeviceMesh` to describe the data-parallel group(s)
Reference: `references/pytorch_device_mesh_tutorial.md` (why DeviceMesh exists & how it manages process groups).
---
### 2) Build model on meta device (recommended for very large models)
For big models, initialize on `meta`, apply sharding, then materialize weights on GPU:
- `with torch.device("meta"): model = ...`
- apply `fully_shard(...)` on submodules, then `fully_shard(model)`
- `model.to_empty(device="cuda")`
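- `model.reset_parameters()` if your model defines it (otherwise call `reset_parameters()` on the submodules that have it, or run your own init routine)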
Reference: `references/pytorch_fsdp2_tutorial.md` (migration guide shows this flow explicitly).
---
### 3) Apply `fully_shard()` bottom-up (wrapping policy = “apply where needed”)
**Do not** only call `fully_shard` on the topmost module.
Recommended sharding pattern for transformer-like models:
- iterate modules, `if isinstance(m, TransformerBlock): fully_shard(m, ...)`
- then `fully_shard(model, ...)`
Why:
- `fully_shard` forms “parameter groups” for collective efficiency and excludes params already grouped by earlier calls. Bottom-up gives better overlap and lower peak memory.
Reference: `references/pytorch_fully_shard_api.md` (bottom-up requirement and why).
---
### 4) Configure `reshard_after_forward` for memory/perf trade-offs
Default behavior:
- `None` means `True` for non-root modules and `False` for root modules (good default).
Heuristics:
- If you’re memory-bound: keep defaults or force `True` on many blocks.
- If you’re throughput-bound and can afford memory: consider keeping unsharded params longer (root often `False`).
- Advanced: pass an `int` to reshard to a smaller world size after forward (e.g., the intra-node GPU count); it must be a non-trivial divisor of the shard mesh size (neither 1 nor the mesh size itself).
Reference: `references/pytorch_fully_shard_api.md` (full semantics).
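To make the divisor constraint on the `int` form concrete, this small sketch lists the values that would qualify for a given shard mesh size. `valid_reshard_sizes` is a hypothetical helper, not a PyTorch function.

```python
def valid_reshard_sizes(shard_mesh_size: int) -> list[int]:
    """Non-trivial divisors of the shard mesh size (excludes 1 and the size itself)."""
    return [d for d in range(2, shard_mesh_size) if shard_mesh_size % d == 0]


print(valid_reshard_sizes(16))  # [2, 4, 8]
```

With 16 ranks spread over two 8-GPU nodes, for instance, `reshard_after_forward=8` would be a candidate that matches the intra-node size. Whether it helps depends on your interconnect and memory headroom.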
---
### 5) Mixed precision & offload (optional but common)
FSDP2 uses:
- `mp_policy=MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=..., cast_forward_inputs=...)`
- `offload_policy=CPUOffloadPolicy()` if you want CPU offload
Rules of thumb:
- Start with BF16 parameters/reductions on H100/A100-class GPUs (if numerically stable for your model).
- Keep `reduce_dtype` aligned with your gradient reduction expectations.
- If you use CPU offload, budget for PCIe/NVLink traffic and runtime overhead.
Reference: `references/pytorch_fully_shard_api.md` (MixedPrecisionPolicy / OffloadPolicy classes).
---
### 6) Optimizer, gradient clipping, accumulation
- Create the optimizer **after** sharding so it holds DTensor params.
- For gradient accumulation, use the FSDP2 mechanism (`set_requires_gradient_sync`) instead of FSDP1’s `no_sync()` context manager.
Gradient clipping:
- Use the approach shown in the FSDP2 tutorial (“Gradient Clipping and Optimizer with DTensor”), because parameters/gradients are DTensors.
Reference: `references/pytorch_fsdp2_tutorial.md`.
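The accumulation and clipping points above might combine into a training step like this. It is a sketch under assumptions: `model` is the root module after `fully_shard`, `micro_batches` is a list of `(inputs, targets)` pairs, and the function name and loss scaling are illustrative choices.

```python
import torch


def train_step_with_accumulation(model, optimizer, micro_batches, loss_fn,
                                 max_norm: float = 1.0) -> None:
    """One optimizer step over several micro-batches.

    Gradient reduction is skipped on every micro-batch except the last.
    """
    num_micro = len(micro_batches)
    for i, (inputs, targets) in enumerate(micro_batches):
        # Only the final micro-batch triggers the gradient reduce-scatter.
        model.set_requires_gradient_sync(i == num_micro - 1)
        loss = loss_fn(model(inputs), targets) / num_micro
        loss.backward()
    # Assumption: the stock utility is applied directly to the DTensor params.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```

Skipping synchronization saves communication but keeps unreduced gradients in memory between micro-batches, so it trades memory for throughput.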
---
### 7) Checkpointing: prefer DCP or distributed state dict helpers
Two recommended approaches:
**A) Distributed Checkpoint (DCP) — best default**
- DCP saves/loads from multiple ranks in parallel and supports load-time resharding.
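- DCP produces **multiple files** (often at least one per rank) and loads “in place”: the model allocates its (sharded) storage first, and DCP fills those tensors rather than returning new ones.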
**B) Distributed state dict helpers**
- `get_model_state_dict` / `set_model_state_dict` with `StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True, ...)`
- For optimizer: `get_optimizer_state_dict` / `set_optimizer_state_dict`
Avoid:
- Saving DTensor state dicts with plain `torch.save` unless you intentionally convert with `DTensor.full_tensor()` and manage memory carefully.
References:
- `references/pytorch_dcp_overview.md` (DCP behavior and caveats)
- `references/pytorch_dcp_recipe.md` and `references/pytorch_dcp_async_recipe.md` (end-to-end usage)
- `references/pytorch_fsdp2_tutorial.md` (DTensor vs DCP state-dict flows)
- `references/pytorch_examples_fsdp2.md` (working checkpoint scripts)
---
## Workflow checklists (copy-paste friendly)
### Workflow A: Retrofit FSDP2 into an existing training script
- [ ] Launch with `torchrun` and initialize the process group.
- [ ] Set the CUDA device from `LOCAL_RANK`; create a `DeviceMesh` if you need multi-dim parallelism.
- [ ] Build the model (use `meta` if needed), apply `fully_shard` bottom-up, then `fully_shard(model)`.
- [ ] Create the optimizer after sharding so it captures DTensor parameters.
- [ ] Use `model(inputs)` so hooks run; use `set_requires_gradient_sync` for accumulation.
- [ ] Add DCP save/load via `torch.distributed.checkpoint` helpers.
Reference: `references/pytorch_fsdp2_tutorial.md`, `references/pytorch_fully_shard_api.md`, `references/pytorch_device_mesh_tutorial.md`, `references/pytorch_dcp_recipe.md`.
### Workflow B: Add DCP save/load (minimal pattern)
- [ ] Wrap state in `Stateful` or assemble state via `get_state_dict`.
- [ ] Call `dcp.save(...)` from all ranks to a shared path.
- [ ] Call `dcp.load(...)` and restore with `set_state_dict`.
- [ ] Validate any resharding assumptions when loading into a different mesh.
Reference: `references/pytorch_dcp_recipe.md`.
## Debug checklist (what the agent should check first)
1. **All ranks on distinct GPUs?**
   If not, verify `torch.cuda.set_device(LOCAL_RANK)` and your `torchrun` flags.
2. **Did you accidentally call `forward()` directly?**
   Use `model(input)` or explicitly `unshard()` / register forward.
3. **Is `fully_shard()` applied bottom-up?**
   If only root is sharded, expect worse memory/perf and possible confusion.
4. **Optimizer created at the right time?**
   Must be built on DTensor parameters *after* sharding.
5. **Checkpointing path consistent?**
   - If using DCP, don’t mix with ad-hoc `torch.save` unless you understand conversions.
   - Be mindful of PyTorch-version compatibility warnings for DCP.
---
## Common issues and fixes
- **Forward hooks not running** → Call `model(inputs)` (or `unshard()` explicitly) instead of `model.forward(...)`.
- **Optimizer sees non-DTensor params** → Create optimizer after all `fully_shard` calls.
- **Only root module sharded** → Apply `fully_shard` bottom-up on submodules before the root.
- **Memory spikes after forward** → Set `reshard_after_forward=True` for more modules.
- **Gradient accumulation desync** → Use `set_requires_gradient_sync` instead of FSDP1’s `no_sync()`.
Reference: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.
---
## Minimal reference implementation outline (agent-friendly)
The coding agent should implement a script with these labeled blocks:
- `init_distributed()`: init process group, set device
- `build_model_meta()`: model on meta, apply `fully_shard`, materialize weights
- `build_optimizer()`: optimizer created after sharding
- `train_step()`: forward/backward/step with `model(inputs)` and DTensor-aware patterns
- `checkpoint_save/load()`: DCP or distributed state dict helpers
Concrete examples live in `references/pytorch_examples_fsdp2.md` and the official tutorial reference.
---
## References
- `references/pytorch_fsdp2_tutorial.md`
- `references/pytorch_fully_shard_api.md`
- `references/pytorch_ddp_notes.md`
- `references/pytorch_fsdp1_api.md`
- `references/pytorch_device_mesh_tutorial.md`
- `references/pytorch_tp_tutorial.md`
- `references/pytorch_dcp_overview.md`
- `references/pytorch_dcp_recipe.md`
- `references/pytorch_dcp_async_recipe.md`
- `references/pytorch_examples_fsdp2.md`
- `references/torchtitan_fsdp_notes.md` (optional, production notes)
- `references/ray_train_fsdp2_example.md` (optional, integration example)
This skill adds PyTorch FSDP2 (fully_shard) to training scripts so very large models can train across multiple GPUs with correct init, sharding, mixed precision/offload options, and distributed checkpointing. It codifies the required launch pattern, bottom-up sharding, DTensor-aware optimizer creation, and recommended Distributed Checkpoint (DCP) flows. The goal is a reliable retrofit or new training loop that avoids common FSDP2 pitfalls and memory spikes.
The skill initializes `torch.distributed` and sets the CUDA device per process, optionally creates a `DeviceMesh`, and then applies `fully_shard` bottom-up (shard submodules first, then root). It covers meta-device model construction, materialization with `to_empty`/`reset_parameters`, optimizer creation after sharding so it holds DTensor parameters, mixed-precision/offload policies, and DCP or distributed state-dict checkpointing. It requires calling `model(input)` so the FSDP2 hooks run, and it offers patterns for gradient accumulation and clipping with DTensors.
What if my forward hooks aren’t running?
Call `model(inputs)` rather than `model.forward(inputs)`. If you must call `forward` directly, call `unshard()` first or register the forward method explicitly so the hooks run.
When should the optimizer be created?
Create the optimizer after all `fully_shard` calls so it holds the DTensor parameters; an optimizer created earlier references the original, unsharded parameters and will not update the sharded model.