Change error to warning when tensor sizes mismatch while loading params #155368
Open
@PistonY

Description

🚀 The feature, motivation and pitch

I'm currently working on a transfer learning task and encountered an error while loading a pre-trained model. I modified some of the model layers but kept their names unchanged, and got the following error:

RuntimeError: Error(s) in loading state_dict for DataParallel:
        size mismatch for module.conv1.weight: copying a param with shape torch.Size([64, 3, 7, 7]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).
        size mismatch for module.fc.weight: copying a param with shape torch.Size([1000, 2048]) from checkpoint, the shape in current model is torch.Size([10, 2048]).
        size mismatch for module.fc.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([10]).

I know I can fix this by writing some code that filters out the mismatched keys myself (a sketch follows the list below). However, I wonder if this error could be turned into a warning instead. Such situations come up routinely in transfer learning, for example:

  • Applying an image backbone to a small-input dataset like CIFAR (input size 224 → 32), where the stem conv has to be modified to keep the output shape unchanged.
  • Fine-tuning a pre-trained detection or segmentation model on a dataset with a different number of categories, where the prediction layer has to be changed.
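For concreteness, here is a minimal sketch of the filtering I mean; load_matched_state_dict is a hypothetical helper name, not an existing PyTorch API:

import warnings

def load_matched_state_dict(model, state_dict):
    # Keep only the parameters whose shapes match the current model,
    # and warn (rather than raise) about everything else.
    model_state_dict = model.state_dict()
    filtered = {}
    for k, v in state_dict.items():
        if k in model_state_dict and v.shape == model_state_dict[k].shape:
            filtered[k] = v
        else:
            expected = tuple(model_state_dict[k].shape) if k in model_state_dict else None
            warnings.warn('Skipping {}: checkpoint shape {}, model shape {}'.format(
                k, tuple(v.shape), expected))
    # strict=False tolerates the keys dropped above.
    model.load_state_dict(filtered, strict=False)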

Alternatives

Something like the code below works; however, having PyTorch handle it natively with a warning would be a more concise solution. I believe people know what they are doing when they pass strict=False.

for k in state_dict:
    if k in model_state_dict:
        if state_dict[k].shape != model_state_dict[k].shape:
            print('Skip loading parameter {}: required shape {}, '
                  'loaded shape {}.'.format(
                      k, model_state_dict[k].shape, state_dict[k].shape))
            if 'class_embed' in k:
                print('load class_embed: {} shape={}'.format(k, state_dict[k].shape))
                # Slice the checkpoint rows to match the current number of classes.
                num_classes = model_state_dict[k].shape[0]
                if num_classes in (1, 2, 3):
                    state_dict[k] = state_dict[k][1:num_classes + 1]
                else:
                    raise NotImplementedError(
                        'invalid shape: {}'.format(model_state_dict[k].shape))
                continue
            # Otherwise fall back to the model's own (freshly initialized) weights.
            state_dict[k] = model_state_dict[k]
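
After the loop, the adjusted dict loads without errors (assuming model and state_dict are defined as above):

model.load_state_dict(state_dict, strict=False)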

Additional context

No response

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

Metadata

Assignees

No one assigned

    Labels

    has workaround
    module: nn (Related to torch.nn)
    needs research (We need to decide whether or not this merits inclusion, based on research world)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Projects

    Status: To pick up

    Milestone

    No milestone

    Development

    No branches or pull requests
