[zerobubble] support distributed layers for zero bubble v scheduler. #6035

Open: wants to merge 2 commits into base: feature/zerobubble
15 changes: 15 additions & 0 deletions colossalai/pipeline/stage_manager.py
@@ -63,6 +63,7 @@ def get_stage_index(
         stage: Optional[int] = None,
         num_model_chunks: Optional[int] = None,
         num_stages: Optional[int] = None,
+        use_zbv: bool = False,
     ) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
         """
         Get the start index and end index of layers for each stage.
@@ -78,13 +79,27 @@ def get_stage_index(
             - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk

         """
         stage = self.stage if stage is None else stage
         num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
         num_stages = self.num_stages if num_stages is None else num_stages
+        if use_zbv:
+            assert (
+                num_model_chunks == 2
+            ), "When you use the zero-bubble V scheduler, the number of model chunks should be equal to 2."

         num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)

         stage_indices = []
+        if use_zbv:
+            stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]])
+            stage_indices.append(
+                [
+                    num_layers_per_stage_accumulated[2 * num_stages - stage - 1],
+                    num_layers_per_stage_accumulated[2 * num_stages - stage],
+                ]
+            )
+            return stage_indices
+
         for model_chunk in range(num_model_chunks):
             start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
             end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
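The V-shaped index mapping in the `use_zbv` branch can be checked in isolation with a small standalone sketch. It is not the code from this PR: `zbv_stage_index` is a hypothetical helper name, the prefix sums use `itertools.accumulate` as a pure-Python stand-in for `np.insert(np.cumsum(...), 0, 0)`, and it assumes `layers_per_stage` holds one layer count per virtual stage, i.e. `2 * num_stages` entries.

```python
from itertools import accumulate
from typing import List, Tuple


def zbv_stage_index(
    layers_per_stage: List[int],
    stage: int,
    num_stages: int,
) -> List[Tuple[int, int]]:
    """Return (start, end) layer ranges for the two model chunks held by `stage`.

    Hypothetical helper for illustration; not part of colossalai.
    """
    # Assumption: the V schedule uses 2 model chunks, so there are
    # 2 * num_stages virtual stages, each with its own layer count.
    if len(layers_per_stage) != 2 * num_stages:
        raise ValueError("expected one layer count per virtual stage (2 * num_stages)")
    if not 0 <= stage < num_stages:
        raise ValueError("stage out of range")

    # Prefix sums with a leading 0, equivalent to np.insert(np.cumsum(x), 0, 0).
    acc = [0, *accumulate(layers_per_stage)]

    # Chunk 0 goes down the V (virtual stage == stage); chunk 1 comes back up
    # (virtual stage == 2 * num_stages - stage - 1).
    mirrored = 2 * num_stages - stage - 1
    return [(acc[stage], acc[stage + 1]), (acc[mirrored], acc[mirrored + 1])]


if __name__ == "__main__":
    layers = [2] * 8  # 16 layers over 4 stages x 2 chunks
    for s in range(4):
        print(s, zbv_stage_index(layers, s, num_stages=4))
```

With four stages and two layers per virtual stage, stage 0 holds the first and last layer ranges, while stage 3 holds the two middle ranges back to back, which is the "V" shape the scheduler relies on.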