diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 354f110f0b0d..9f8971a4faf5 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -63,6 +63,7 @@ def get_stage_index( stage: Optional[int] = None, num_model_chunks: Optional[int] = None, num_stages: Optional[int] = None, + use_zbv: bool = False, ) -> Union[Tuple[int, int], List[Tuple[int, int]]]: """ Get the start index and end index of layers for each stage. @@ -78,6 +79,10 @@ def get_stage_index( - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk """ + if use_zbv: + assert ( + num_model_chunks == 2 + ), f"When you use the zero-bubble V scheduler, the number of model chunks should be equal to 2." stage = self.stage if stage is None else stage num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks num_stages = self.num_stages if num_stages is None else num_stages @@ -85,6 +90,16 @@ def get_stage_index( num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) stage_indices = [] + if use_zbv: + stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]]) + stage_indices.append( + [ + num_layers_per_stage_accumulated[2 * num_stages - stage - 1], + num_layers_per_stage_accumulated[2 * num_stages - stage], + ] + ) + return stage_indices + for model_chunk in range(num_model_chunks): start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]