diff --git a/colossalai/gemini/chunk/chunk.py b/colossalai/gemini/chunk/chunk.py
index a9f0f7eae2cf..d50565749c89 100644
--- a/colossalai/gemini/chunk/chunk.py
+++ b/colossalai/gemini/chunk/chunk.py
@@ -51,7 +51,6 @@ def alloc_storage(tensor: torch.Tensor) -> None:
 
 
 class Chunk:
-    _total_number = 0
 
     def __init__(self,
@@ -140,6 +139,10 @@ def __init__(self,
         # if the cpu_shard has been visited during the training step, the flag is True
         self.cpu_vis_flag = False
 
+        # whether to record l2 norm for the gradient clipping calculation
+        self.l2_norm_flag = False
+        self.l2_norm = None
+
     @property
     def memory_usage(self) -> Dict[str, int]:
         cuda_memory = 0
@@ -213,16 +216,28 @@ def can_reduce(self):
 
     @property
     def has_inf_or_nan(self) -> bool:
-        """Check if the chunk has inf or nan values in CUDA.
+        """Check if the chunk has inf or nan values on CUDA.
         """
         if self.is_gathered:
             valid_tensor = self.chunk_total[:self.utilized_size]
         else:
-            assert self.cuda_shard is not None    # only check in CUDA
+            assert self.cuda_shard is not None    # only check on CUDA
             valid_tensor = self.cuda_shard[:self.valid_end]
 
         return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
 
+    def set_l2_norm(self) -> None:
+        """Record the l2 norm of this chunk on CUDA.
+        """
+        assert self.l2_norm is None, "you are calculating the l2 norm twice"
+        if self.is_gathered:
+            valid_tensor = self.chunk_total[:self.utilized_size]
+        else:
+            assert self.cuda_shard is not None    # calculate on CUDA
+            valid_tensor = self.cuda_shard[:self.valid_end]
+        chunk_l2_norm = valid_tensor.data.float().norm(2)
+        self.l2_norm = chunk_l2_norm.item()**2
+
     def append_tensor(self, tensor: torch.Tensor):
         """Add a tensor to the chunk.
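Reviewer note (not part of the patch): Chunk.set_l2_norm stores the *squared* L2 norm of the chunk's valid payload, computed in fp32, so the per-chunk values can later be summed and all-reduced before a single square root is taken. A minimal self-contained sketch of that computation, with a random tensor standing in for chunk_total[:utilized_size] (or cuda_shard[:valid_end] when the chunk is sharded):

    import torch

    def squared_l2_norm(valid_tensor: torch.Tensor) -> float:
        # cast to fp32 before norm() so a large fp16 gradient chunk cannot overflow,
        # then square the result: squared norms are additive across chunks and ranks
        return valid_tensor.float().norm(2).item() ** 2

    grad_payload = torch.randn(1024)    # stand-in for one chunk's valid gradient region
    print(squared_l2_norm(grad_payload))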
diff --git a/colossalai/nn/optimizer/zero_optimizer.py b/colossalai/nn/optimizer/zero_optimizer.py
index 09ecbb2c714b..62a0be329dd0 100644
--- a/colossalai/nn/optimizer/zero_optimizer.py
+++ b/colossalai/nn/optimizer/zero_optimizer.py
@@ -1,3 +1,4 @@
+import math
 from enum import Enum
 from typing import Any, Dict, Set, Tuple
 
@@ -56,6 +57,8 @@ def __init__(self,
                  growth_interval: int = 1000,
                  hysteresis: int = 2,
                  max_scale: float = 2**32,
+                 clipping_norm: float = 0.0,
+                 norm_type: float = 2.0,
                  **defaults: Any):
         super().__init__(optim)
         assert isinstance(module, ZeroDDP)
@@ -66,11 +69,17 @@ def __init__(self,
         self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
         self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
         self.chunk16_set: Set[Chunk] = set()
+        self.clipping_flag = clipping_norm > 0.0
+        self.max_norm = clipping_norm
+
+        if self.clipping_flag:
+            assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
 
         params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
         for p, fp32_p in zip(params_list, module.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             if chunk_16 not in self.chunk16_set:
+                chunk_16.l2_norm_flag = self.clipping_flag
                 self.chunk16_set.add(chunk_16)
 
         self.__init__optimizer()
@@ -128,12 +137,45 @@ def _check_overflow(self):
 
         return self._found_overflow.item() > 0
 
-    def _unscale_grads(self):
+    def _calc_global_norm(self) -> float:
+        norm_sqr: float = 0.0
+        group_to_norm = dict()
+        for c16 in self.chunk16_set:
+            assert c16.l2_norm is not None
+
+            if c16.is_gathered:
+                norm_sqr += c16.l2_norm
+            else:
+                # this chunk is sharded, use communication to collect total norm
+                if c16.torch_pg not in group_to_norm:
+                    group_to_norm[c16.torch_pg] = 0.0
+                group_to_norm[c16.torch_pg] += c16.l2_norm
+
+            c16.l2_norm = None    # clear l2 norm
+
+        comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
+        for group, part_norm in group_to_norm.items():
+            comm_buffer.fill_(part_norm)
+            dist.all_reduce(comm_buffer, group=group)
+            norm_sqr += comm_buffer.item()
+
+        global_norm = math.sqrt(norm_sqr)
+        return global_norm
+
+    def _unscale_and_clip_grads(self):
         assert self.optim_state == OptimState.SCALED
+
+        combined_scale = self.loss_scale
+        if self.clipping_flag:
+            total_norm = self._calc_global_norm()
+            clip = ((total_norm / self.loss_scale) + 1e-6) / self.max_norm
+            if clip > 1:
+                combined_scale = clip * self.loss_scale
+
         for group in self.optim.param_groups:
             for p in group['params']:
                 if p.grad is not None:
-                    p.grad.data.div_(self.loss_scale)
+                    p.grad.data.div_(combined_scale)
         self.optim_state = OptimState.UNSCALED
 
     @property
@@ -147,16 +189,21 @@ def zero_grad(self, *args, **kwargs):
     def step(self, *args, **kwargs):
         self._maybe_move_fp32_params()
         self._set_grad_ptr()
-        # unscale grads if scaled
-        if self.optim_state == OptimState.SCALED:
-            self._unscale_grads()
+
         found_inf = self._check_overflow()
-        self.grad_scaler.update(found_inf)
         if found_inf:
+            self.optim_state = OptimState.UNSCALED    # no need to unscale grad
+            self.grad_scaler.update(found_inf)    # update gradient scaler
             self._logger.info(f'Found overflow. Skip step')
-            self.zero_grad()
+            self.zero_grad()    # reset all gradients
             self._update_fp16_params()
             return
+
+        # unscale grads if scaled
+        if self.optim_state == OptimState.SCALED:
+            self._unscale_and_clip_grads()
+        self.grad_scaler.update(found_inf)
+
         ret = self.optim.step(*args, **kwargs)
         self._register_states()
         self.zero_grad()
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 175146ebb158..ca937ff932cf 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -302,7 +302,11 @@ def grad_handle(self, p, grad):
                     chunk.chunk_total.div_(chunk.pg_size)
                 else:
                     chunk.cuda_shard.div_(chunk.pg_size)
+                # check overflow elements
                 self.overflow_counter += chunk.has_inf_or_nan
+                # record l2 norm for gradient clipping
+                if chunk.l2_norm_flag:
+                    chunk.set_l2_norm()
                 self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
         return empty_grad
 
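Reviewer note (not part of the patch): once grad_handle has recorded the squared norm of every reduced chunk, _calc_global_norm adds the gathered chunks directly and all-reduces the sharded ones within their process groups, and _unscale_and_clip_grads folds loss-scale removal and clipping into a single division. A single-process sketch of that arithmetic; the names mirror the patch, the numbers are made up:

    import math

    loss_scale = 32.0
    max_norm = 1.0                            # corresponds to clipping_norm
    chunk_sq_norms = [400.0, 900.0, 1200.0]   # hypothetical per-chunk squared norms of the *scaled* grads

    total_norm = math.sqrt(sum(chunk_sq_norms))             # global norm of the scaled gradients (50.0)
    clip = ((total_norm / loss_scale) + 1e-6) / max_norm    # > 1 means the unscaled norm exceeds max_norm
    combined_scale = clip * loss_scale if clip > 1 else loss_scale

    # every gradient is divided once by combined_scale, which removes the loss scale
    # and, when clipping kicks in, shrinks the global norm to roughly max_norm
    print(total_norm / combined_scale)    # ~1.0 here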
diff --git a/tests/test_gemini/update/test_grad_clip.py b/tests/test_gemini/update/test_grad_clip.py
new file mode 100644
index 000000000000..185521edb357
--- /dev/null
+++ b/tests/test_gemini/update/test_grad_clip.py
@@ -0,0 +1,117 @@
+from functools import partial
+from time import time
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.amp import convert_to_apex_amp
+from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
+from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
+from colossalai.nn.parallel import ZeroDDP
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from tests.components_to_test import run_fwd_bwd
+from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.test_tensor.common_utils import debug_print, set_seed
+
+
+def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
+    zero_dict = model.state_dict(only_rank_0=False)
+    torch_dict = torch_model.state_dict()
+
+    for key, value in torch_dict.items():
+        # key is 'module.model.PARAMETER', so we truncate it
+        key = key[7:]
+        if key == 'model.lm_head.weight':
+            continue
+        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
+        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
+        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
+        assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
+
+
+@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
+@parameterize('model_name', ['gpt2'])
+def exam_grad_clipping(placement_policy, model_name: str):
+    set_seed(1912)
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+
+    torch_model = model_builder().cuda()
+    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32)
+    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
+    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
+    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
+
+    init_dev = get_current_device()
+    with ColoInitContext(device=init_dev):
+        model = model_builder()
+
+    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
+        p.data.copy_(torch_p.data)
+
+    world_size = torch.distributed.get_world_size()
+    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict[world_size]['chunk_size'] = 5000
+    config_dict[world_size]['keep_gathered'] = False
+    if placement_policy != 'cuda':
+        init_device = torch.device('cpu')
+    else:
+        init_device = None
+    chunk_manager = ChunkManager(config_dict, init_device=init_device)
+    gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    model = ZeroDDP(model, gemini_manager, pin_memory=True)
+
+    optimizer = HybridAdam(model.parameters(), lr=1e-3)
+    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)
+
+    model.train()
+    torch_model.train()
+
+    set_seed(dist.get_rank() * 3 + 128)
+    for i, (data, label) in enumerate(train_dataloader):
+        if i > 2:
+            break
+        data = data.cuda()
+        label = label.cuda()
+
+        zero_optim.zero_grad()
+        torch_optim.zero_grad()
+
+        torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
+        loss = run_fwd_bwd(model, data, label, criterion, zero_optim)
+        assert_close(torch_loss, loss)
+
+        import apex.amp as apex_amp
+        torch.nn.utils.clip_grad_norm_(apex_amp.master_params(torch_optim), 1.0)
+        torch_optim.step()
+        zero_optim.step()
+
+        check_param(model, torch_model)
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    exam_grad_clipping()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 2])
+@rerun_if_address_is_in_use()
+def test_grad_clip(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_grad_clip(2)
diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py
index f9d51ea79aed..f9333f3d1ccb 100644
--- a/tests/test_gemini/update/test_optim.py
+++ b/tests/test_gemini/update/test_optim.py
@@ -42,7 +42,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
         temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
         # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
-        assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-2)
+        assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
 
 
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
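Reviewer note (not part of the patch): from the user's side, clipping is switched on purely through the new ZeroOptimizer arguments, exactly as the new test does. A hedged usage sketch that assumes a ZeroDDP-wrapped model, a train_dataloader and a criterion already exist; the backward call mirrors what the test's run_fwd_bwd helper does internally:

    from colossalai.nn.optimizer import HybridAdam
    from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer

    optim = HybridAdam(model.parameters(), lr=1e-3)
    # clipping_norm > 0.0 enables clipping; only norm_type=2.0 is supported
    zero_optim = ZeroOptimizer(optim, model, initial_scale=32, clipping_norm=1.0)

    for data, label in train_dataloader:
        zero_optim.zero_grad()
        loss = criterion(model(data.cuda()), label.cuda())
        zero_optim.backward(loss)    # scaled backward, as used inside run_fwd_bwd
        zero_optim.step()            # unscales, clips to the global L2 norm, then updates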