Skip to content

Commit

Permalink
adds epoch-boundary checkpoint saving (#160)
Browse files Browse the repository at this point in the history
Currently, we save checkpoints:

1. whenever enough samples have been processed (controlled by `save_samples`), and
2. at the end of training, but only when `--save_last` is set.

This adds an opt-in `checkpoint_at_epoch` option (default `False`) that saves a
checkpoint after each epoch finishes, so one can set `save_samples` very high and
save ONLY at epoch boundaries.

Signed-off-by: James Kunstle <[email protected]>
  • Loading branch information
JamesKunstle committed Aug 8, 2024
1 parent 3170300 commit 3a91777
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/instructlab/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ class TrainingArgs(BaseModel):
warmup_steps: int
is_padding_free: bool
random_seed: int = 42
checkpoint_at_epoch: bool = False

mock_data: Optional[bool] = False
mock_data_len: int = 0
Expand Down
17 changes: 17 additions & 0 deletions src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,15 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
if local_rank == 0:
inner_pb.update(1)
torch.cuda.empty_cache()

if args.checkpoint_at_epoch:
save_hf_format_ds(
args,
model,
tokenizer,
global_step * args.samples_per_gpu * world_size,
is_lora=bool(args.lora_r),
)
if args.save_last:
save_hf_format_ds(
args,
Expand Down Expand Up @@ -615,6 +624,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
f"--chat-tmpl-path={train_args.chat_tmpl_path}",
]

if train_args.checkpoint_at_epoch:
command.append("--checkpoint_at_epoch")

if train_args.mock_data:
command.append("--mock_data")
if train_args.mock_len:
Expand Down Expand Up @@ -734,6 +746,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
parser.add_argument(
"--save_last", action="store_true", help="save after finishing training"
)
parser.add_argument(
"--checkpoint_at_epoch",
action="store_true",
help="Save a model checkpoint after finishing an epoch.",
)
parser.add_argument("--log_level", type=str, default="INFO")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--mock_data", action="store_true")
Expand Down

0 comments on commit 3a91777

Please sign in to comment.