🚀 The feature, motivation and pitch
I'm currently working on a transfer learning task and encountered an error while loading a pre-trained model. I modified some of the model layers but kept their names unchanged, and the error looks like this:
RuntimeError: Error(s) in loading state_dict for DataParallel:
size mismatch for module.conv1.weight: copying a param with shape torch.Size([64, 3, 7, 7]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).
size mismatch for module.fc.weight: copying a param with shape torch.Size([1000, 2048]) from checkpoint, the shape in current model is torch.Size([10, 2048]).
size mismatch for module.fc.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([10]).
I know I can fix this by writing some code to filter out the mismatched keys. However, I wonder if it's possible to turn this error into a warning when it occurs.
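For reference, a minimal sketch of that filtering approach (the variable names are illustrative only; state_dict is the checkpoint's state dict and model is the modified model):

# Drop checkpoint entries whose shape no longer matches the modified model,
# then load the rest non-strictly.
model_state_dict = model.state_dict()
filtered = {k: v for k, v in state_dict.items()
            if k in model_state_dict and v.shape == model_state_dict[k].shape}
model.load_state_dict(filtered, strict=False)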
Such situations commonly arise in transfer learning, for example:
- Applying an image backbone to a dataset with small inputs (224 down to 32) such as CIFAR, where we need to modify the stem conv to keep the output shape unchanged (see the sketch after this list).
- Using a pre-trained detection or segmentation model on a different dataset with a different number of categories, where we need to change the prediction layer.
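To make the first scenario concrete, here is a minimal sketch (assuming torchvision's ResNet-50; the checkpoint path is a placeholder) of the kind of modification that produces the size-mismatch error above:

import torch
import torchvision

# Adapt an ImageNet-style ResNet-50 to CIFAR-10.
model = torchvision.models.resnet50(weights=None)

# Replace the 7x7/stride-2 stem conv with a 3x3/stride-1 conv so the
# feature map is not downsampled too aggressively for 32x32 inputs.
model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
# CIFAR images are small, so the initial max-pool is often dropped as well.
model.maxpool = torch.nn.Identity()

# Replace the 1000-way ImageNet head with a 10-way CIFAR-10 head.
model.fc = torch.nn.Linear(model.fc.in_features, 10)

# Loading the ImageNet checkpoint now raises the RuntimeError shown above,
# even with strict=False, because conv1.weight and fc.* no longer match:
# state_dict = torch.load('resnet50_imagenet.pth')  # placeholder path
# model.load_state_dict(state_dict, strict=False)   # -> RuntimeError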
Alternatives
Something like the code below works; however, having PyTorch natively handle this with a warning would also be a concise solution. I believe people know what they are doing when they pass strict=False.
# state_dict: parameters loaded from the checkpoint
# model_state_dict: model.state_dict() of the (modified) target model
for k in state_dict:
    if k in model_state_dict:
        if state_dict[k].shape != model_state_dict[k].shape:
            print('Skip loading parameter {}, required shape {}, '
                  'loaded shape {}.'.format(
                      k, model_state_dict[k].shape, state_dict[k].shape))
            if 'class_embed' in k:
                # Special case: slice the classification head instead of skipping it.
                print('load class_embed: {} shape={}'.format(k, state_dict[k].shape))
                if model_state_dict[k].shape[0] == 1:
                    state_dict[k] = state_dict[k][1:2]
                elif model_state_dict[k].shape[0] == 2:
                    state_dict[k] = state_dict[k][1:3]
                elif model_state_dict[k].shape[0] == 3:
                    state_dict[k] = state_dict[k][1:4]
                else:
                    raise NotImplementedError(
                        'invalid shape: {}'.format(model_state_dict[k].shape))
                continue
            # Otherwise keep the randomly initialized parameter from the model.
            state_dict[k] = model_state_dict[k]
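After the loop has adjusted the mismatched entries, the checkpoint loads without raising (a sketch; load_state_dict returns its missing and unexpected keys as a named tuple):

# Non-strict load: remaining missing/unexpected keys are reported, not raised.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)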
Additional context
No response
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki