FSDP (Fully Sharded Data Parallel)
FSDP (Fully Sharded Data Parallel) splits the model parameters, gradients, and optimizer states on top of data parallelism. It breaks down the all-reduce communication operation into reduce-scatter and all-gather, thereby reducing the peak memory usage of individual parallel workers. This allows training on larger models or with larger micro-batch sizes.
Let’s demonstrate how to accelerate FSDP training using TorchAcc optimization with a simple example.
Torch Native Task
Below is the code for GPT2 in Torch which training with bfloat16:
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
# Model and tokenizer setup
model_name = 'gpt2'
max_length = 512
batch_size = 16
model_config = AutoConfig.from_pretrained(model_name, cache_dir='./log/model_cache')
model = AutoModelForCausalLM.from_config(model_config)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
tokenizer.model_max_length = max_length
tokenizer.pad_token = tokenizer.eos_token
# Dataset and dataloader
def preprocess_function(examples):
examples['text'] = [text for text in examples['text'] if len(text) > 0]
tokenized = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length)
tokenized['labels'] = tokenized['input_ids'].copy()
return tokenized
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train').map(preprocess_function, batched=True)
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
device = 'cuda:0'
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
model.train()
for step, batch in enumerate(tqdm(train_dataloader, unit='batch')):
optimizer.zero_grad()
batch = {k: v.to(device) for k, v in batch.items()}
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
loss = model(**batch)['loss']
loss.backward()
optimizer.step()
if step % 100 == 0:
print(f'step: {step}, loss: {loss.item():.4f}')
FSDP
You only need to configure the TorchAcc Config and pass it to the torchacc.accelerate function to easily achieve FSDP training.
import torch
+ import torchacc
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
# Model and tokenizer setup
model_name = 'gpt2'
max_length = 512
batch_size = 16
model_config = AutoConfig.from_pretrained(model_name, cache_dir='./log/model_cache')
model = AutoModelForCausalLM.from_config(model_config)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
tokenizer.model_max_length = max_length
tokenizer.pad_token = tokenizer.eos_token
# Dataset and dataloader
def preprocess_function(examples):
examples['text'] = [text for text in examples['text'] if len(text) > 0]
tokenized = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length)
tokenized['labels'] = tokenized['input_ids'].copy()
return tokenized
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train').map(preprocess_function, batched=True)
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
- device = 'cuda:0'
- model.to(device)
+ config = torchacc.Config()
+ config.compute.bf16 = True
+ config.dist.fsdp.size = 4
+ config.dist.fsdp.wrap_layer_cls = {'GPT2Block'}
+ model, train_dataloader = torchacc.accelerate(model, train_dataloader, config)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
model.train()
for step, batch in enumerate(tqdm(train_dataloader, unit='batch')):
optimizer.zero_grad()
- batch = {k: v.to(device) for k, v in batch.items()}
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
loss = model(**batch)['loss']
loss.backward()
optimizer.step()
if step % 100 == 0:
print(f'step: {step}, loss: {loss.item():.4f}')
The main changes:
Removed the code for moving the model to the CUDA device.
Configured FSDP and compute information through torchacc
Config.Used the
torchacc.accelerateinterface to wrap and accelerate themodelandtrain_dataloader.
The shell command for running FSDP tasks is the same as data parallelism:
$ torchrun --nproc_per_node=4 resnet_acc.py
Checkpoint Save/Load
Save Checkpoint
Save the model parameters for each FSDP shard, optimizer states, and LR scheduler. Note that you need to save shard_metadata to restore the correct shard information.
# 1) Save model shards
torchacc.dist.rendezvous("saving_model")
torchacc.dist.mark_step()
ckpt = {
'model': model.state_dict(),
'shard_metadata': model.get_shard_metadata(),
}
torchacc.save(ckpt, CKPT_DIR, master_only=False)
# 2) Save optimizer states and LR scheduler
torchacc.dist.rendezvous("saving_optimizer_states")
torchacc.save(optimizer.state_dict(), OPTIMIZER_DIR)
torchacc.save(lr_scheduler.state_dict(), LR_SCHEDULER_DIR)
Load from Checkpoint
# 1) Reorganize shards
if torchacc.dist.is_master_ordinal(local=False):
torchacc.dist.fsdp.consolidate_sharded_model_checkpoints(
CKPT_DIR, ckpt_suffix)
torchacc.dist.rendezvous("ckpt_consolidation")
# 2) Load model
ckpt_consolidated = torch.load("consolidated.pth")
model.load_state_dict(ckpt_consolidated['model'])
# 3) Load optimizer states and LR scheduler
optimizer_state = torch.load(OPTIMIZER_DIR)
lr_scheduler_state = torch.load(LR_SCHEDULER_DIR)
optimizer.load_state_dict(optimizer_state)
lr_scheduler.load_state_dict(lr_scheduler_state)
Configurable parameters
The configurable parameters of FSDP are as follows:
Parameter |
Type |
Description |
|---|---|---|
size |
int |
Number of fully sharded data parallel. |
wrap_layer_cls |
Set[str] |
Submodules with one of the |
flatten_parameters |
bool |
If |
sync_module_states |
bool |
If |
use_spmd |
bool |
If |
shard_output_callable |
callable |
A callable to shard the output of the forward pass. The callable should have the signature (output, mesh) -> None. If None, the default implementation will shard the first tensor in the output. If the output is a tuple, only the first tensor will be sharded. |