FSDP (Fully Sharded Data Parallel)

FSDP (Fully Sharded Data Parallel) builds on data parallelism by sharding the model parameters, gradients, and optimizer states across the data-parallel workers instead of replicating them on every worker. It breaks the all-reduce communication of plain data parallelism into a reduce-scatter and an all-gather, which reduces the peak memory usage of each worker. This makes it possible to train larger models or to use larger micro-batch sizes.
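The all-reduce decomposition can be illustrated without any GPUs. Below is a minimal pure-Python simulation, where each inner list stands for one worker's tensor and the function names are hypothetical (they are not TorchAcc APIs). It shows that a reduce-scatter followed by an all-gather gives the same result as an all-reduce, while between the two steps each worker holds only 1/N of the summed data. In real FSDP, each worker runs the optimizer update on its shard between the two steps, and it is the updated parameter shards that get all-gathered.

```python
# Simulated collectives over plain Python lists; each inner list is one worker's data.
def all_reduce(inputs):
    """Every worker ends up with the element-wise sum of all inputs."""
    total = [sum(values) for values in zip(*inputs)]
    return [list(total) for _ in inputs]


def reduce_scatter(inputs):
    """Worker r ends up with only the r-th chunk of the element-wise sum."""
    world_size = len(inputs)
    length = len(inputs[0])
    assert length % world_size == 0, "this sketch assumes the length divides evenly"
    chunk = length // world_size
    total = [sum(values) for values in zip(*inputs)]
    return [total[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def all_gather(shards):
    """Every worker ends up with the concatenation of all shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


grads = [[1, 2, 3, 4], [10, 20, 30, 40]]  # 2 workers, 4 gradient elements each
shards = reduce_scatter(grads)
print(shards)  # [[11, 22], [33, 44]]: each worker keeps half of the summed gradient
assert all_gather(shards) == all_reduce(grads)
```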

The following simple example shows how to accelerate FSDP training with TorchAcc.

Torch Native Task

Below is native PyTorch code that trains GPT-2 with bfloat16 autocast:

import torch
import torchacc
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Model and tokenizer setup
model_name = 'gpt2'
max_length = 512
batch_size = 16

model_config = AutoConfig.from_pretrained(model_name, cache_dir='./log/model_cache')
model = AutoModelForCausalLM.from_config(model_config)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
tokenizer.model_max_length = max_length
tokenizer.pad_token = tokenizer.eos_token

# Dataset and dataloader
def preprocess_function(examples):
    examples['text'] = [text for text in examples['text'] if len(text) > 0]
    tokenized = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length)
    tokenized['labels'] = tokenized['input_ids'].copy()
    return tokenized

dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train').map(preprocess_function, batched=True)
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

device = 'cuda:0'
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

model.train()
for step, batch in enumerate(tqdm(train_dataloader, unit='batch')):
    optimizer.zero_grad()
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        loss = model(**batch)['loss']
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(f'step: {step}, loss: {loss.item():.4f}')

FSDP

To enable FSDP training, configure a torchacc.Config and pass it to torchacc.accelerate. The diff below shows the changes relative to the native task (lines starting with - are removed, lines starting with + are added):

  import torch
  import torchacc
  from datasets import load_dataset
  from tqdm import tqdm
  from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

  # Model and tokenizer setup
  model_name = 'gpt2'
  max_length = 512
  batch_size = 16

  model_config = AutoConfig.from_pretrained(model_name, cache_dir='./log/model_cache')
  model = AutoModelForCausalLM.from_config(model_config)
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
  tokenizer.model_max_length = max_length
  tokenizer.pad_token = tokenizer.eos_token

  # Dataset and dataloader
  def preprocess_function(examples):
      examples['text'] = [text for text in examples['text'] if len(text) > 0]
      tokenized = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length)
      tokenized['labels'] = tokenized['input_ids'].copy()
      return tokenized

  dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train').map(preprocess_function, batched=True)
  dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
  train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

- device = 'cuda:0'
- model.to(device)
+ config = torchacc.Config()
+ config.compute.bf16 = True
+ config.dist.fsdp.size = 4
+ config.dist.fsdp.wrap_layer_cls = {'GPT2Block'}
+ model, train_dataloader = torchacc.accelerate(model, train_dataloader, config)
  optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

  model.train()
  for step, batch in enumerate(tqdm(train_dataloader, unit='batch')):
      optimizer.zero_grad()
-     batch = {k: v.to(device) for k, v in batch.items()}
      with torch.cuda.amp.autocast(dtype=torch.bfloat16):
          loss = model(**batch)['loss']
      loss.backward()
      optimizer.step()
      if step % 100 == 0:
          print(f'step: {step}, loss: {loss.item():.4f}')

The main changes are:

  • Removed the code that moves the model and each batch to the CUDA device.

  • Configured FSDP and compute settings through torchacc.Config.

  • Wrapped and accelerated the model and train_dataloader with torchacc.accelerate.
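wrap_layer_cls selects FSDP units by class name. Conceptually, the module tree is walked and every submodule whose class name is in the set becomes its own FSDP unit; in FSDP generally, the parameters of a unit are gathered just before it runs and released afterwards, so setting {'GPT2Block'} handles one transformer block at a time. The sketch below models only that matching step. Module, named_modules, and modules_to_wrap are hypothetical toy stand-ins, not TorchAcc internals.

```python
# Toy stand-ins for torch.nn.Module; every name here is hypothetical.
class Module:
    def __init__(self, *children):
        self.children = list(children)

    def named_modules(self, prefix=""):
        """Yield (dotted_name, module) pairs in pre-order, similar in spirit to
        torch.nn.Module.named_modules (children are named by index here)."""
        yield prefix, self
        for i, child in enumerate(self.children):
            child_prefix = f"{prefix}.{i}" if prefix else str(i)
            yield from child.named_modules(child_prefix)


class Embedding(Module):
    pass


class GPT2Block(Module):
    pass


class LayerNorm(Module):
    pass


class GPT2Model(Module):
    pass


def modules_to_wrap(root, wrap_layer_cls):
    """Names of all submodules whose class name appears in wrap_layer_cls."""
    return [name for name, module in root.named_modules()
            if type(module).__name__ in wrap_layer_cls]


model = GPT2Model(Embedding(), GPT2Block(LayerNorm()), GPT2Block(LayerNorm()), LayerNorm())
print(modules_to_wrap(model, {"GPT2Block"}))  # ['1', '2']: one FSDP unit per block
```

With a real PyTorch model, {type(m).__name__ for m in model.modules()} lists the class names that could be put into wrap_layer_cls.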

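The shell command for running an FSDP task is the same as for data parallelism. Here --nproc_per_node=4 launches four processes, matching config.dist.fsdp.size = 4 (gpt2_fsdp.py stands for the training script above):

$ torchrun --nproc_per_node=4 gpt2_fsdp.py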
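torchrun launches --nproc_per_node processes and exports RANK, LOCAL_RANK, and WORLD_SIZE to each of them. The following is a minimal sketch of reading these variables, plus a sanity check against the FSDP size from the config. The check encodes an assumption, not documented TorchAcc behavior: that WORLD_SIZE should be a multiple of config.dist.fsdp.size, with the quotient being the number of FSDP groups. dist_info and num_fsdp_groups are hypothetical helpers.

```python
import os


def dist_info(env=None):
    """Return (rank, local_rank, world_size) as exported by torchrun."""
    env = os.environ if env is None else env
    rank = int(env["RANK"])              # global rank, 0 .. WORLD_SIZE - 1
    local_rank = int(env["LOCAL_RANK"])  # rank within the current node
    world_size = int(env["WORLD_SIZE"])  # total number of processes
    return rank, local_rank, world_size


def num_fsdp_groups(world_size, fsdp_size):
    """ASSUMPTION (not documented TorchAcc behavior): WORLD_SIZE must be a
    multiple of fsdp.size, and the quotient is the number of FSDP groups."""
    if world_size % fsdp_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} is not a multiple of fsdp.size={fsdp_size}")
    return world_size // fsdp_size


# What worker 2 sees under `torchrun --nproc_per_node=4` on a single node.
fake_env = {"RANK": "2", "LOCAL_RANK": "2", "WORLD_SIZE": "4"}
rank, local_rank, world_size = dist_info(fake_env)
print(rank, local_rank, world_size)    # 2 2 4
print(num_fsdp_groups(world_size, 4))  # 1
```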

Checkpoint Save/Load

Save Checkpoint

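Save the model shard held by each rank, the optimizer states, and the LR scheduler state. Save shard_metadata together with each model shard: it records how the parameters were sharded and is needed to restore the full model. Because the optimizer is built from the sharded parameters, its states are sharded as well and must also be saved per rank. In the snippets below, CKPT_DIR, OPTIMIZER_DIR, LR_SCHEDULER_DIR, ckpt_suffix, and lr_scheduler are placeholders for your own paths and scheduler.

import os

# Each rank writes its own files; RANK and WORLD_SIZE are set by torchrun.
# CKPT_DIR and OPTIMIZER_DIR are used as path prefixes (e.g. ending in '/').
rank, world_size = int(os.environ['RANK']), int(os.environ['WORLD_SIZE'])
shard_name = f'rank-{rank}-of-{world_size}.pth'

# 1) Save model shards (one file per rank, hence master_only=False)
torchacc.dist.rendezvous("saving_model")
torchacc.dist.mark_step()
ckpt = {
    'model': model.state_dict(),
    'shard_metadata': model.get_shard_metadata(),
}
torchacc.save(ckpt, CKPT_DIR + shard_name, master_only=False)

# 2) Save optimizer states (sharded, one file per rank) and the LR scheduler
torchacc.dist.rendezvous("saving_optimizer_states")
torchacc.save(optimizer.state_dict(), OPTIMIZER_DIR + shard_name, master_only=False)
torchacc.save(lr_scheduler.state_dict(), LR_SCHEDULER_DIR)

Load from Checkpoint

import os

import torch

rank, world_size = int(os.environ['RANK']), int(os.environ['WORLD_SIZE'])
shard_name = f'rank-{rank}-of-{world_size}.pth'

# 1) Consolidate the model shards on the global master only.
#    ckpt_suffix must match the shard file names saved above.
if torchacc.dist.is_master_ordinal(local=False):
    torchacc.dist.fsdp.consolidate_sharded_model_checkpoints(
        CKPT_DIR, ckpt_suffix)
torchacc.dist.rendezvous("ckpt_consolidation")

# 2) Load the model from the consolidated checkpoint
#    (adjust the path to where step 1 wrote the consolidated file)
ckpt_consolidated = torch.load("consolidated.pth")
model.load_state_dict(ckpt_consolidated['model'])

# 3) Load this rank's optimizer shard and the LR scheduler state
optimizer.load_state_dict(torch.load(OPTIMIZER_DIR + shard_name))
lr_scheduler.load_state_dict(torch.load(LR_SCHEDULER_DIR))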
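The shard_metadata is what makes consolidation possible: a shard on its own does not record the original parameter shape or how much padding was added to split it evenly. The toy pure-Python sketch below shards each flattened parameter across ranks and then consolidates it again. shard_params, consolidate, and the metadata layout are hypothetical and much simpler than TorchAcc's actual checkpoint format.

```python
import math


def shard_params(params, world_size):
    """Split every flat parameter into world_size equal shards, zero-padding the tail.

    params maps a name to (shape, flat list of values). Returns one dict per rank
    plus the metadata needed to undo the split.
    """
    shards = [{} for _ in range(world_size)]
    metadata = {}
    for name, (shape, values) in params.items():
        chunk = math.ceil(len(values) / world_size)
        padded = values + [0.0] * (chunk * world_size - len(values))
        for r in range(world_size):
            shards[r][name] = padded[r * chunk:(r + 1) * chunk]
        metadata[name] = {"shape": shape, "numel": len(values)}
    return shards, metadata


def consolidate(shards, metadata):
    """Concatenate the rank shards in rank order, drop padding, and restore shapes."""
    full = {}
    for name, meta in metadata.items():
        flat = [x for shard in shards for x in shard[name]]
        full[name] = (meta["shape"], flat[:meta["numel"]])
    return full


params = {"wte": ((2, 3), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
          "bias": ((5,), [0.1, 0.2, 0.3, 0.4, 0.5])}
shards, metadata = shard_params(params, world_size=4)
print(shards[3]["bias"])  # [0.0, 0.0]: rank 3 holds only padding for this parameter
assert consolidate(shards, metadata) == params
```

Without "numel" the padding could not be stripped, and without "shape" the flat values could not be reshaped, which is why the shards alone are not enough.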

Configurable parameters

The configurable parameters of FSDP are as follows:

  • size (int): The FSDP degree, i.e. the number of workers across which parameters, gradients, and optimizer states are sharded.

  • wrap_layer_cls (Set[str]): Class names of submodules to wrap; every submodule whose class name is in this set is wrapped as a separate FSDP unit.

  • flatten_parameters (bool): If True, parameters are flattened into a single contiguous tensor for all_gather and reduce_scatter, which can improve speed. In this case, you cannot apply separate optimizer parameter groups to different original parameters in the wrapped module (e.g. zero weight decay for bias terms or BatchNorm submodules), since all the original parameters become a single concatenated vector.

  • sync_module_states (bool): If True, each FSDP module broadcasts its parameters and buffers from rank 0 so that they are identical across ranks. This adds communication overhead and extra GPU memory usage during initialization.

  • use_spmd (bool): If True, use the SPMD-based FSDP implementation.

  • shard_output_callable (callable): A callable that shards the output of the forward pass, with the signature (output, mesh) -> None. If None, the default implementation shards the first tensor in the output; if the output is a tuple, only its first tensor is sharded.
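With flatten_parameters=True, the parameters of a wrapped module are concatenated into one buffer, so the optimizer sees a single tensor instead of several. The pure-Python sketch below (flatten and unflatten_one are hypothetical helpers, and lists stand in for tensors) shows the offset bookkeeping, and why a per-parameter setting such as zero weight decay for biases can no longer be expressed: one weight-decay value applies to every element of the buffer.

```python
def flatten(params):
    """Concatenate parameters into one flat list and record each one's [begin, end) slice."""
    flat, offsets, start = [], {}, 0
    for name, values in params.items():
        flat.extend(values)
        offsets[name] = (start, start + len(values))
        start += len(values)
    return flat, offsets


def unflatten_one(flat, offsets, name):
    """Recover one original parameter from the flat buffer
    (a copy here; real implementations use tensor views)."""
    begin, end = offsets[name]
    return flat[begin:end]


params = {"weight": [0.5, -0.5, 1.0], "bias": [0.1]}
flat, offsets = flatten(params)
print(flat)     # [0.5, -0.5, 1.0, 0.1]
print(offsets)  # {'weight': (0, 3), 'bias': (3, 4)}

# Decoupled weight decay applied to the whole buffer: the bias is decayed too,
# because the optimizer cannot tell where "weight" ends and "bias" begins.
lr, weight_decay = 0.1, 0.01
flat = [x - lr * weight_decay * x for x in flat]
print(unflatten_one(flat, offsets, "bias"))
```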