From 3596bd60e765ed4366a51e12b402725c46877122 Mon Sep 17 00:00:00 2001 From: Jianbing Wu <50580578+KimbingNg@users.noreply.github.com> Date: Fri, 22 Mar 2024 19:52:40 +0800 Subject: [PATCH] Fixes bug by raising exception on size mismatch --- colossalai/checkpoint_io/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index e1800f29b0af..47f0ba0afc0e 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -558,6 +558,10 @@ def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True) missing_keys = missing_keys.append(sub_missing_keys) if strict: + if len(error_msgs) > 0: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) + ) if len(unexpected_keys) > 0: error_msgs = "Unexpected key(s) in state_dict: {}. ".format( ", ".join('"{}"'.format(k) for k in unexpected_keys)