FSDP (Fully Sharded Data Parallel)
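FSDP (Fully Sharded Data Parallel) shards the model parameters, gradients, and optimizer states across data-parallel workers, so each worker stores only a fraction of them. Instead of the single all-reduce used by plain data parallelism, FSDP all-gathers a layer's parameters just before they are needed in the forward and backward passes, and reduce-scatters the gradients so that each worker keeps only its own gradient shard. This lowers the peak memory usage of each worker, which makes it possible to train larger models or to use larger micro-batch sizes.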
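The communication pattern described above can be checked with a toy, single-process simulation in plain Python. It uses no real devices or collectives, and all function names are illustrative. It shows that a reduce-scatter followed by an all-gather yields the same result as one all-reduce. In real FSDP the all-gather is applied to parameters rather than to the reduced gradients. The `per_worker_bytes` helper assumes 16 bytes of model state per parameter (fp32 weights, fp32 gradients and two fp32 Adam moments). It ignores activations and temporarily gathered parameters.

```python
# Toy simulation in plain Python: "workers" are list entries, and the
# collectives are ordinary functions, not real communication.


def all_reduce(grads):
    """Every worker ends up with the element-wise sum of all gradients."""
    total = [sum(vals) for vals in zip(*grads)]
    return [list(total) for _ in grads]


def reduce_scatter(grads):
    """Worker i ends up with only the summed i-th chunk (its shard)."""
    n = len(grads)
    size = len(grads[0])
    assert size % n == 0, "pad the vector so it splits evenly"
    chunk = size // n
    return [
        [sum(g[j] for g in grads) for j in range(i * chunk, (i + 1) * chunk)]
        for i in range(n)
    ]


def all_gather(shards):
    """Every worker receives the concatenation of all shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def per_worker_bytes(num_params, num_workers, bytes_per_param=16):
    """Model-state bytes per worker when params, grads and optimizer states
    are split evenly across num_workers (assumes 16 bytes/param: fp32
    weights, fp32 grads, two fp32 Adam moments; activations not counted)."""
    return num_params * bytes_per_param / num_workers


if __name__ == "__main__":
    grads = [[1, 2, 3, 4, 5, 6, 7, 8],
             [10, 20, 30, 40, 50, 60, 70, 80],
             [100, 200, 300, 400, 500, 600, 700, 800],
             [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000]]
    shards = reduce_scatter(grads)      # worker 0 holds [1111, 2222]
    assert all_gather(shards) == all_reduce(grads)
    # GPT-2 small has about 124M parameters; compare 1 worker with 4.
    print(per_worker_bytes(124_000_000, 1) / 2**30, "GiB without sharding")
    print(per_worker_bytes(124_000_000, 4) / 2**30, "GiB with 4-way sharding")
```

The reduce-scatter leaves each worker with a quarter of the summed gradient. That is exactly the slice of the parameters and optimizer state it owns.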
Let’s walk through a simple example of FSDP training accelerated with TorchAcc.
Torch Native Task
Below is native PyTorch code that trains GPT-2 with bfloat16 mixed precision:
```python
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Model and tokenizer setup
model_name = 'gpt2'
max_length = 512
batch_size = 16
model_config = AutoConfig.from_pretrained(model_name, cache_dir='./log/model_cache')
model = AutoModelForCausalLM.from_config(model_config)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
tokenizer.model_max_length = max_length
tokenizer.pad_token = tokenizer.eos_token

# Dataset and dataloader
def preprocess_function(examples):
    examples['text'] = [text for text in examples['text'] if len(text) > 0]
    tokenized = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length)
    tokenized['labels'] = tokenized['input_ids'].copy()
    return tokenized

dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train').map(preprocess_function, batched=True)
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

device = 'cuda:0'
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

model.train()
for step, batch in enumerate(tqdm(train_dataloader, unit='batch')):
    optimizer.zero_grad()
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        loss = model(**batch)['loss']
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(f'step: {step}, loss: {loss.item():.4f}')
```
FSDP
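To enable FSDP, create a `torchacc.Config`, set its FSDP options, and pass it to the `torchacc.accelerate` function together with the model and dataloader. In the listing below, lines prefixed with `+` are added and lines prefixed with `-` are removed relative to the native version.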
```diff
 import torch
+import torchacc
 from datasets import load_dataset
 from tqdm import tqdm
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

 # Model and tokenizer setup
 model_name = 'gpt2'
 max_length = 512
 batch_size = 16
 model_config = AutoConfig.from_pretrained(model_name, cache_dir='./log/model_cache')
 model = AutoModelForCausalLM.from_config(model_config)
 tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
 tokenizer.model_max_length = max_length
 tokenizer.pad_token = tokenizer.eos_token

 # Dataset and dataloader
 def preprocess_function(examples):
     examples['text'] = [text for text in examples['text'] if len(text) > 0]
     tokenized = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length)
     tokenized['labels'] = tokenized['input_ids'].copy()
     return tokenized

 dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train').map(preprocess_function, batched=True)
 dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
 train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

-device = 'cuda:0'
-model.to(device)
+config = torchacc.Config()
+config.compute.bf16 = True
+config.dist.fsdp.size = 4
+config.dist.fsdp.wrap_layer_cls = {'GPT2Block'}
+model, train_dataloader = torchacc.accelerate(model, train_dataloader, config)
 optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

 model.train()
 for step, batch in enumerate(tqdm(train_dataloader, unit='batch')):
     optimizer.zero_grad()
-    batch = {k: v.to(device) for k, v in batch.items()}
     with torch.cuda.amp.autocast(dtype=torch.bfloat16):
         loss = model(**batch)['loss']
     loss.backward()
     optimizer.step()
     if step % 100 == 0:
         print(f'step: {step}, loss: {loss.item():.4f}')
```
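The main changes:

- Removed the code that moves the model and each batch to the CUDA device; device placement is handled by `torchacc.accelerate` and the dataloader it returns.
- Configured FSDP and compute settings through `torchacc.Config`.
- Used the `torchacc.accelerate` interface to wrap and accelerate the `model` and `train_dataloader`.

Note that the optimizer is created after `torchacc.accelerate`, so it is built from the parameters of the wrapped, sharded model.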
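FSDP tasks are launched with the same shell command as data-parallel tasks. Here `--nproc_per_node=4` starts four processes, matching `config.dist.fsdp.size = 4`:

```shell
$ torchrun --nproc_per_node=4 gpt2_acc.py
```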
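Under `torchrun`, every process receives the `WORLD_SIZE`, `RANK` and `LOCAL_RANK` environment variables. A hypothetical pre-flight check such as `check_fsdp_launch` below can catch a mismatch between `--nproc_per_node` and `config.dist.fsdp.size` early. The rule it enforces is that the world size must be a multiple of the FSDP size. That rule is an assumption of this sketch, not a requirement documented on this page.

```python
import os


def check_fsdp_launch(fsdp_size, env=None):
    """Hypothetical pre-flight check; returns (world_size, rank, local_rank).

    torchrun sets WORLD_SIZE, RANK and LOCAL_RANK for every process. Requiring
    WORLD_SIZE to be a multiple of fsdp_size is an assumption of this sketch.
    """
    env = os.environ if env is None else env
    world_size = int(env.get("WORLD_SIZE", "1"))
    rank = int(env.get("RANK", "0"))
    local_rank = int(env.get("LOCAL_RANK", "0"))
    if world_size % fsdp_size != 0:
        raise ValueError(
            f"WORLD_SIZE={world_size} is not a multiple of fsdp.size={fsdp_size}; "
            "adjust --nproc_per_node or config.dist.fsdp.size")
    return world_size, rank, local_rank


if __name__ == "__main__":
    # Simulated environment of one of the 4 processes started by
    # `torchrun --nproc_per_node=4`; in a real job, call check_fsdp_launch(4).
    fake_env = {"WORLD_SIZE": "4", "RANK": "2", "LOCAL_RANK": "2"}
    print(check_fsdp_launch(4, fake_env))
```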
Checkpoint Save/Load
Save Checkpoint
Save the model parameters and optimizer states of each FSDP shard, as well as the LR scheduler state. Each rank must also save its `shard_metadata`, which is needed to restore the correct shard information.
```python
import os

import torchacc

# `model` is the object returned by torchacc.accelerate;
# `optimizer` and `lr_scheduler` come from your training script.
shard_meta_data = model.model.model.get_shard_metadata()

CKPT_DIR = "./ckpt_dir"
MODEL_NAME = f"rank{torchacc.dist.local_rank()}-of-{torchacc.dist.world_size()}-model.pth"
OPTIM_NAME = f"rank{torchacc.dist.local_rank()}-of-{torchacc.dist.world_size()}-optim.pth"
LR_SCHEDULER_DIR = os.path.join(CKPT_DIR, "lr_scheduler.pth")  # example path; choose your own
os.makedirs(CKPT_DIR, exist_ok=True)

# 1) Each rank saves its model shard
torchacc.dist.rendezvous("saving_model")
model_ckpt = {
    'model': model.state_dict(),
    'shard_metadata': shard_meta_data,
}
torchacc.save(model_ckpt, os.path.join(CKPT_DIR, MODEL_NAME), master_only=False)

# 2) Each rank saves its optimizer shard
torchacc.dist.rendezvous("saving_optimizer_states")
optim_ckpt = {
    'optimizer': optimizer.state_dict(),
    'shard_metadata': shard_meta_data,
}
torchacc.save(optim_ckpt, os.path.join(CKPT_DIR, OPTIM_NAME), master_only=False)

# 3) Save the LR scheduler state
torchacc.save(lr_scheduler.state_dict(), LR_SCHEDULER_DIR)
```
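Every rank writes its own pair of files, named after its rank and the world size. The sketch below uses a hypothetical helper, `shard_ckpt_names`, to reproduce that naming. It then uses Python's `fnmatch` to check that the glob patterns passed to `consolidate_and_reshard_fsdp_ckpts` in the offline-consolidation step match all of the files. The names are built from `torchacc.dist.local_rank()`, which, as its name suggests, is presumably the per-node rank. If so, files from different nodes would collide if several nodes wrote to one shared `CKPT_DIR`.

```python
import fnmatch


def shard_ckpt_names(rank, world_size):
    """Hypothetical helper: per-rank file names in the MODEL_NAME/OPTIM_NAME format."""
    return (f"rank{rank}-of-{world_size}-model.pth",
            f"rank{rank}-of-{world_size}-optim.pth")


if __name__ == "__main__":
    world_size = 4
    names = [n for r in range(world_size) for n in shard_ckpt_names(r, world_size)]
    # The same glob patterns used by the consolidation command
    print(fnmatch.filter(names, "rank*-of-*-model.pth"))
    print(fnmatch.filter(names, "rank*-of-*-optim.pth"))
```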
Load Checkpoint
If the FSDP configuration does not change, you can load the sharded checkpoints directly and continue training. For example, shards saved with `fsdp_size = 4` can be loaded with `fsdp_size = 4`.
CKPT_DIR="./ckpt_dir"
MODEL_NAME=f"rank{torchacc.dist.local_rank()}-of-{torchacc.dist.world_size()}-model.pth"
OPTIM_NAME=f"rank{torchacc.dist.local_rank()}-of-{torchacc.dist.world_size()}-optim.pth"
model_ckpt = torch.load(os.path.join(CKPT_DIR, MODEL_NAME))
model.load_state_dict(model_ckpt['model'])
optim_ckpt = torch.load(os.path.join(CKPT_DIR, OPTIM_NAME))
optimizer.load_state_dict(optim_ckpt['optimizer'])
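Sharded checkpoints can only be loaded with an unchanged FSDP configuration, so it can help to verify the shard count before calling `load_state_dict`. The sketch below is a hypothetical helper, not a TorchAcc API. It parses the `rank{r}-of-{w}-{model,optim}.pth` names written when saving and checks that every rank's model and optimizer shard is present. It then compares the recorded world size with the current one. In a real job, pass `torchacc.dist.world_size()` as `world_size`.

```python
import os
import re

# Matches the per-rank names used when saving, e.g. "rank3-of-4-model.pth".
SHARD_RE = re.compile(r"^rank(\d+)-of-(\d+)-(model|optim)\.pth$")


def saved_shard_world_size(ckpt_dir):
    """Return the world size encoded in the shard file names in ckpt_dir."""
    sizes, seen = set(), set()
    for name in os.listdir(ckpt_dir):
        m = SHARD_RE.match(name)
        if m:
            sizes.add(int(m.group(2)))
            seen.add((int(m.group(1)), m.group(3)))
    if not sizes:
        raise ValueError(f"no shard checkpoints found in {ckpt_dir}")
    if len(sizes) > 1:
        raise ValueError(f"mixed world sizes in {ckpt_dir}: {sorted(sizes)}")
    size = sizes.pop()
    expected = {(r, kind) for r in range(size) for kind in ("model", "optim")}
    missing = expected - seen
    if missing:
        raise ValueError(f"missing shards (rank, kind): {sorted(missing)}")
    return size


def check_resume(ckpt_dir, world_size):
    """Raise if the shards cannot be loaded directly by world_size ranks."""
    saved = saved_shard_world_size(ckpt_dir)
    if saved != world_size:
        raise RuntimeError(
            f"checkpoint has {saved} shards but {world_size} ranks are running; "
            "consolidate/reshard it offline first")
    return saved


if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as d:
        for r in range(4):
            for kind in ("model", "optim"):
                open(os.path.join(d, f"rank{r}-of-4-{kind}.pth"), "wb").close()
        print(check_resume(d, world_size=4))
```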
Offline Consolidation and Reshard
TorchAcc also supports consolidating and resharding FSDP checkpoints offline. For example, you can save sharded checkpoints with `fsdp_size = 4`, consolidate them offline into a full checkpoint, and then load the full checkpoint. You can also reshard the checkpoints into 8 shards and then load them shard by shard, as in Load Checkpoint, with the new FSDP configuration `fsdp_size = 8`.
Run `consolidate_and_reshard_fsdp_ckpts --help` for more options.
```shell
CKPT_DIR=./ckpt_dir  # directory containing the rank*-of-*-{model,optim}.pth shards
# Consolidate the model and optimizer shards into full checkpoints
consolidate_and_reshard_fsdp_ckpts --ckpt_dir "$CKPT_DIR" --model_ckpt_name_pattern "rank*-of-*-model.pth" --optimizer_ckpt_name_pattern "rank*-of-*-optim.pth"
# Use --reshard_num to reshard the FSDP checkpoints (here into 8 shards)
consolidate_and_reshard_fsdp_ckpts --ckpt_dir "$CKPT_DIR" --model_ckpt_name_pattern "rank*-of-*-model.pth" --optimizer_ckpt_name_pattern "rank*-of-*-optim.pth" --reshard_num 8
```
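Conceptually, consolidation concatenates the per-rank pieces of each flattened parameter and strips the padding. Resharding then pads and splits the full tensor again. The plain-Python sketch below illustrates this for a single flat parameter. It is a simplified model, not the implementation of `consolidate_and_reshard_fsdp_ckpts`, which also has to handle many parameters and the optimizer state. Knowing each parameter's original size (`numel`) is what makes it possible to drop the padding. That is presumably one reason each rank saves `shard_metadata`.

```python
def consolidate(shards, numel):
    """Concatenate per-rank flat shards and drop the trailing padding."""
    flat = [x for shard in shards for x in shard]
    return flat[:numel]


def reshard(full, num_shards, pad_value=0):
    """Pad to a multiple of num_shards and split into equal contiguous shards."""
    chunk = -(-len(full) // num_shards)  # ceiling division
    padded = full + [pad_value] * (chunk * num_shards - len(full))
    return [padded[i * chunk:(i + 1) * chunk] for i in range(num_shards)]


if __name__ == "__main__":
    param = list(range(10))            # a flat parameter with 10 elements
    old = reshard(param, 4)            # as saved with fsdp_size = 4 (padded to 12)
    full = consolidate(old, numel=10)  # offline consolidation
    new = reshard(full, 8)             # resharded for fsdp_size = 8 (padded to 16)
    print(old)
    print(new)
```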
Load from Full Checkpoint
```python
import torch

LR_SCHEDULER_DIR = "./ckpt_dir/lr_scheduler.pth"  # example path; must match the path used when saving

# 1) Load model
model_consolidated = torch.load("model_consolidated.pth")  # default consolidated model file name
model.load_state_dict(model_consolidated)
# 2) Load optimizer
optimizer_consolidated = torch.load("optimizer_consolidated.pth")  # default consolidated optimizer file name
optimizer.load_state_dict(optimizer_consolidated)
# 3) Load LR scheduler
lr_scheduler_state = torch.load(LR_SCHEDULER_DIR)
lr_scheduler.load_state_dict(lr_scheduler_state)
```
Configurable parameters
The configurable parameters of FSDP are as follows:
| Parameter | Type | Description |
|---|---|---|
| `size` | int | Number of workers in the fully sharded data parallel group (the FSDP degree). |
| `wrap_layer_cls` | Set[str] | Submodules whose class name is one of the names in this set (for example `{'GPT2Block'}`) are wrapped as separate FSDP units. |
| `flatten_parameters` | bool | If … |
| `sync_module_states` | bool | If … |
| `use_spmd` | bool | If … |
| `shard_output_callable` | callable | A callable that shards the output of the forward pass, with the signature `(output, mesh) -> None`. If `None`, the default implementation shards the first tensor in the output; if the output is a tuple, only its first tensor is sharded. |
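The training example sets `size` and `wrap_layer_cls` as attributes of `config.dist.fsdp`. The fragment below assumes the other parameters in the table live in the same namespace; check the `Config` class of your TorchAcc version. The values shown for `flatten_parameters` and `sync_module_states` are examples, not recommended defaults.

```python
import torchacc

config = torchacc.Config()
config.compute.bf16 = True
config.dist.fsdp.size = 4
config.dist.fsdp.wrap_layer_cls = {'GPT2Block'}
# Assumed attribute paths (same namespace as `size` and `wrap_layer_cls`);
# example values only.
config.dist.fsdp.flatten_parameters = True
config.dist.fsdp.sync_module_states = True
```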