[zero] add L2 gradient clipping for ZeRO (#2112)
* [zero] add L2 gradient clipping

* [testing] add MlpModel

* [zero] add unit test for grad clipping

* fix atol
1SAA committed Dec 9, 2022
1 parent 70a8556 commit 63fbba3
Showing 5 changed files with 194 additions and 11 deletions.
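
For orientation, a minimal usage sketch of what this commit enables: passing a positive clipping_norm to ZeroOptimizer turns on L2 clipping of the global gradient norm. It mirrors the new unit test below; the model, data, and criterion objects are placeholders, and the forward/backward lines follow ColossalAI's usual optimizer-wrapper flow rather than anything added in this diff.

# Hedged sketch, assuming `model` is already wrapped in ZeroDDP and
# colossalai.launch() has been called; only `clipping_norm` is new here.
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer

optim = HybridAdam(model.parameters(), lr=1e-3)
# clipping_norm > 0 enables clipping; norm_type must stay 2.0 (the only supported norm)
zero_optim = ZeroOptimizer(optim, model, initial_scale=32, clipping_norm=1.0)

loss = criterion(model(data), label)
zero_optim.backward(loss)    # scales the loss, as in the existing ZeroOptimizer flow
zero_optim.step()            # unscales and, with this commit, clips to a global L2 norm of 1.0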
21 changes: 18 additions & 3 deletions colossalai/gemini/chunk/chunk.py
@@ -51,7 +51,6 @@ def alloc_storage(tensor: torch.Tensor) -> None:


class Chunk:

_total_number = 0

def __init__(self,
@@ -140,6 +139,10 @@ def __init__(self,
# if the cpu_shard has been visited during the training step, the flag is True
self.cpu_vis_flag = False

# whether to record l2 norm for the gradient clipping calculation
self.l2_norm_flag = False
self.l2_norm = None

@property
def memory_usage(self) -> Dict[str, int]:
cuda_memory = 0
@@ -213,16 +216,28 @@ def can_reduce(self):

@property
def has_inf_or_nan(self) -> bool:
"""Check if the chunk has inf or nan values in CUDA.
"""Check if the chunk has inf or nan values on CUDA.
"""
if self.is_gathered:
valid_tensor = self.chunk_total[:self.utilized_size]
else:
assert self.cuda_shard is not None # only check in CUDA
assert self.cuda_shard is not None # only check on CUDA
valid_tensor = self.cuda_shard[:self.valid_end]

return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()

def set_l2_norm(self) -> None:
"""Record l2 norm of this chunks on CUDA.
"""
assert self.l2_norm is None, "you are calculating the l2 norm twice"
if self.is_gathered:
valid_tensor = self.chunk_total[:self.utilized_size]
else:
assert self.cuda_shard is not None # calculate on CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
chunk_l2_norm = valid_tensor.data.float().norm(2)
self.l2_norm = chunk_l2_norm.item()**2

def append_tensor(self, tensor: torch.Tensor):
"""Add a tensor to the chunk.
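
set_l2_norm above records the squared chunk norm (chunk_l2_norm.item()**2) rather than the norm itself. A short, self-contained sketch of why, under the assumption made explicit in zero_optimizer.py below: these values are later summed over chunks and ranks before a single square root.

# Hedged sketch: squared L2 norms are additive, so per-chunk (and per-shard)
# contributions can be combined with a plain sum or an all-reduce.
import math

def global_l2_norm(squared_chunk_norms):
    # ||g||_2 = sqrt(sum_i ||g_i||_2^2) when the pieces g_i partition g
    return math.sqrt(sum(squared_chunk_norms))

# two "chunks" with norms 3 and 4 give a global norm of 5, matching the full tensor
assert global_l2_norm([3.0 ** 2, 4.0 ** 2]) == 5.0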
61 changes: 54 additions & 7 deletions colossalai/nn/optimizer/zero_optimizer.py
@@ -1,3 +1,4 @@
import math
from enum import Enum
from typing import Any, Dict, Set, Tuple

@@ -56,6 +57,8 @@ def __init__(self,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
clipping_norm: float = 0.0,
norm_type: float = 2.0,
**defaults: Any):
super().__init__(optim)
assert isinstance(module, ZeroDDP)
@@ -66,11 +69,17 @@ def __init__(self,
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
self.chunk16_set: Set[Chunk] = set()
self.clipping_flag = clipping_norm > 0.0
self.max_norm = clipping_norm

if self.clipping_flag:
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"

params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
for p, fp32_p in zip(params_list, module.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
if chunk_16 not in self.chunk16_set:
chunk_16.l2_norm_flag = self.clipping_flag
self.chunk16_set.add(chunk_16)

self.__init__optimizer()
@@ -128,12 +137,45 @@ def _check_overflow(self):

return self._found_overflow.item() > 0

def _unscale_grads(self):
def _calc_global_norm(self) -> float:
norm_sqr: float = 0.0
group_to_norm = dict()
for c16 in self.chunk16_set:
assert c16.l2_norm is not None

if c16.is_gathered:
norm_sqr += c16.l2_norm
else:
# this chunk is sharded, use communication to collect total norm
if c16.torch_pg not in group_to_norm:
group_to_norm[c16.torch_pg] = 0.0
group_to_norm[c16.torch_pg] += c16.l2_norm

c16.l2_norm = None # clear l2 norm

comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
for group, part_norm in group_to_norm.items():
comm_buffer.fill_(part_norm)
dist.all_reduce(comm_buffer, group=group)
norm_sqr += comm_buffer.item()

global_norm = math.sqrt(norm_sqr)
return global_norm

def _unscale_and_clip_grads(self):
assert self.optim_state == OptimState.SCALED

combined_scale = self.loss_scale
if self.clipping_flag:
total_norm = self._calc_global_norm()
clip = ((total_norm / self.loss_scale) + 1e-6) / self.max_norm
if clip > 1:
combined_scale = clip * self.loss_scale

for group in self.optim.param_groups:
for p in group['params']:
if p.grad is not None:
p.grad.data.div_(self.loss_scale)
p.grad.data.div_(combined_scale)
self.optim_state = OptimState.UNSCALED

@property
@@ -147,16 +189,21 @@ def zero_grad(self, *args, **kwargs):
def step(self, *args, **kwargs):
self._maybe_move_fp32_params()
self._set_grad_ptr()
# unscale grads if scaled
if self.optim_state == OptimState.SCALED:
self._unscale_grads()

found_inf = self._check_overflow()
self.grad_scaler.update(found_inf)
if found_inf:
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
self._logger.info(f'Found overflow. Skip step')
self.zero_grad()
self.zero_grad() # reset all gradients
self._update_fp16_params()
return

# unscale grads if scaled
if self.optim_state == OptimState.SCALED:
self._unscale_and_clip_grads()
self.grad_scaler.update(found_inf)

ret = self.optim.step(*args, **kwargs)
self._register_states()
self.zero_grad()
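
_unscale_and_clip_grads above folds loss-scale removal and gradient clipping into a single division per gradient. A standalone sketch of the same arithmetic follows; the 1e-6 term and the clip-only-when-needed branch come from the diff, while the function wrapper and the example numbers are illustrative.

# Hedged sketch of the combined divisor used in _unscale_and_clip_grads.
def combined_scale(loss_scale: float, total_norm: float, max_norm: float) -> float:
    # total_norm is measured on the still-scaled fp16 grads, so divide the scale out first
    clip = ((total_norm / loss_scale) + 1e-6) / max_norm
    # only shrink gradients when the unscaled global norm exceeds max_norm
    return clip * loss_scale if clip > 1 else loss_scale

# grad.div_(combined_scale) then equals (grad / loss_scale) * (max_norm / unscaled_norm)
# when clipping kicks in, and a plain grad / loss_scale otherwise:
print(combined_scale(loss_scale=32.0, total_norm=64.0, max_norm=1.0))  # ~64  (clipped)
print(combined_scale(loss_scale=32.0, total_norm=16.0, max_norm=1.0))  # 32.0 (untouched)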
4 changes: 4 additions & 0 deletions colossalai/nn/parallel/data_parallel.py
@@ -302,7 +302,11 @@ def grad_handle(self, p, grad):
chunk.chunk_total.div_(chunk.pg_size)
else:
chunk.cuda_shard.div_(chunk.pg_size)
# check overflow elements
self.overflow_counter += chunk.has_inf_or_nan
# record l2 norm for gradient clipping
if chunk.l2_norm_flag:
chunk.set_l2_norm()
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
return empty_grad

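
In grad_handle above, the overflow check and the norm recording both happen right after the all-reduce and the division by pg_size, while the reduced gradient is still on CUDA and before the chunk may be offloaded. A small sketch of those two reads on a plain tensor; the shard itself is illustrative.

# Hedged sketch of the two per-chunk checks, on a stand-in gradient shard.
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'  # the real hook runs on CUDA
grad_shard = torch.tensor([3.0, -4.0], device=device)    # stand-in for chunk.cuda_shard

# overflow detection, as in Chunk.has_inf_or_nan
overflowed = torch.isinf(grad_shard).any().item() | torch.isnan(grad_shard).any().item()

# squared L2 norm, as in Chunk.set_l2_norm (25.0 here), recorded only when the
# owning ZeroOptimizer was created with clipping_norm > 0
sq_norm = grad_shard.data.float().norm(2).item() ** 2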
117 changes: 117 additions & 0 deletions tests/test_gemini/update/test_grad_clip.py
@@ -0,0 +1,117 @@
from functools import partial
from time import time

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed


def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict()

for key, value in torch_dict.items():
# key is 'module.model.PARAMETER', so we truncate it
key = key[7:]
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('model_name', ['gpt2'])
def exam_grad_clipping(placement_policy, model_name: str):
set_seed(1912)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

torch_model = model_builder().cuda()
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

init_dev = get_current_device()
with ColoInitContext(device=init_dev):
model = model_builder()

for torch_p, p in zip(torch_model.parameters(), model.parameters()):
p.data.copy_(torch_p.data)

world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)

optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)

model.train()
torch_model.train()

set_seed(dist.get_rank() * 3 + 128)
for i, (data, label) in enumerate(train_dataloader):
if i > 2:
break
data = data.cuda()
label = label.cuda()

zero_optim.zero_grad()
torch_optim.zero_grad()

torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
loss = run_fwd_bwd(model, data, label, criterion, zero_optim)
assert_close(torch_loss, loss)

import apex.amp as apex_amp
torch.nn.utils.clip_grad_norm_(apex_amp.master_params(torch_optim), 1.0)
torch_optim.step()
zero_optim.step()

check_param(model, torch_model)


def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_grad_clipping()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_grad_clip(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
test_grad_clip(2)
2 changes: 1 addition & 1 deletion tests/test_gemini/update/test_optim.py
@@ -42,7 +42,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-2)
assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
