Make Checkpoint.load_objects to accept str and load internally (#2303) by Abo7atm · Pull Request #2305 · pytorch/ignite

Merged · 4 commits · Oct 29, 2021
Changes from all commits
35 changes: 23 additions & 12 deletions ignite/handlers/checkpoint.py
@@ -507,14 +507,14 @@ def _check_objects(objs: Mapping, attr: str) -> None:
                 raise TypeError(f"Object {type(obj)} should have `{attr}` method")
 
     @staticmethod
-    def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
+    def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: Any) -> None:
         """Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``.
 
         Args:
             to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
-            checkpoint: a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
-                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain directly
-                corresponding state_dict.
+            checkpoint: a string filepath or a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
+                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain
+                directly corresponding state_dict.
             kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
                 the user to load part of the pretrained model (useful for example, in Transfer Learning)
 
@@ -537,18 +537,29 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
                 checkpoint = torch.load(checkpoint_fp)
                 Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
 
+                # or using a string for checkpoint filepath
+
+                to_load = to_save
+                checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
+                Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
+
         Note:
             If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
             `DataParallel`_, method ``load_state_dict`` will applied to their internal wrapped model (``obj.module``).
 
         .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
         .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
 
         """
 
+        if isinstance(checkpoint, str):
+            checkpoint_obj = torch.load(checkpoint)
+        else:
+            checkpoint_obj = checkpoint
+
         Checkpoint._check_objects(to_load, "load_state_dict")
-        if not isinstance(checkpoint, collections.Mapping):
-            raise TypeError(f"Argument checkpoint should be a dictionary, but given {type(checkpoint)}")
+        if not isinstance(checkpoint, (collections.Mapping, str)):
+            raise TypeError(f"Argument checkpoint should be a string or a dictionary, but given {type(checkpoint)}")
 
         if len(kwargs) > 1 or any(k for k in kwargs if k not in ["strict"]):
             warnings.warn("kwargs contains keys other than strict and these will be ignored")
@@ -557,22 +568,22 @@ def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs: Any) -> None:
         if len(to_load) == 1:
             # single object and checkpoint is directly a state_dict
            key, obj = list(to_load.items())[0]
-            if key not in checkpoint:
+            if key not in checkpoint_obj:
                 if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                     obj = obj.module
-                obj.load_state_dict(checkpoint, strict=is_state_dict_strict)
+                obj.load_state_dict(checkpoint_obj, strict=is_state_dict_strict)
                 return
 
         # multiple objects to load
         for k, obj in to_load.items():
-            if k not in checkpoint:
+            if k not in checkpoint_obj:
                 raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint")
             if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                 obj = obj.module
             if isinstance(obj, torch.nn.Module):
-                obj.load_state_dict(checkpoint[k], strict=is_state_dict_strict)
+                obj.load_state_dict(checkpoint_obj[k], strict=is_state_dict_strict)
             else:
-                obj.load_state_dict(checkpoint[k])
+                obj.load_state_dict(checkpoint_obj[k])
 
     def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
         """Method returns state dict with saved items: list of ``(priority, filename)`` pairs.
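To make the behavior change concrete, here is a minimal usage sketch. It is not part of the PR: the model, optimizer, and /tmp filepath are illustrative, and the checkpoint file is produced with a plain torch.save instead of a running Engine.

import torch
import torch.nn as nn

from ignite.handlers import Checkpoint

# Illustrative objects; anything exposing state_dict()/load_state_dict() works.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
to_save = {"model": model, "optimizer": optimizer}

# Write a checkpoint dict shaped like the ones Checkpoint/ModelCheckpoint save.
checkpoint_fp = "/tmp/myprefix_checkpoint_40.pth"
torch.save({k: v.state_dict() for k, v in to_save.items()}, checkpoint_fp)

# Before this PR the caller had to load the file first:
Checkpoint.load_objects(to_load=to_save, checkpoint=torch.load(checkpoint_fp))

# After this PR a filepath string is accepted and torch.load runs internally:
Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint_fp)

# strict=False is forwarded to nn.Module.load_state_dict, enabling partial
# loads (e.g. transfer learning), as the docstring notes:
Checkpoint.load_objects(to_load={"model": model}, checkpoint=checkpoint_fp, strict=False)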
14 changes: 13 additions & 1 deletion tests/ignite/handlers/test_checkpoint.py
@@ -1070,7 +1070,7 @@ def test_save_model_optimizer_lr_scheduler_with_validation(dirname):
 
 def test_checkpoint_load_objects():
 
-    with pytest.raises(TypeError, match=r"Argument checkpoint should be a dictionary"):
+    with pytest.raises(TypeError, match=r"Argument checkpoint should be a string or a dictionary"):
         Checkpoint.load_objects({}, [])
 
     with pytest.raises(TypeError, match=r"should have `load_state_dict` method"):
@@ -1107,6 +1107,17 @@ def _get_multiple_objs_to_save():
     trainer = Engine(lambda e, b: None)
     trainer.state = State(epoch=0, iteration=0)
 
+    # case: load from filepath
+    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
+    to_save = _get_multiple_objs_to_save()
+    handler(trainer, to_save)
+    fname = handler.last_checkpoint
+    assert isinstance(fname, str)
+    assert os.path.join(dirname, _PREFIX) in fname
+    assert os.path.exists(fname)
+    Checkpoint.load_objects(to_save, fname)
+    os.remove(fname)
+
     # case: multiple objects
     handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
     to_save = _get_multiple_objs_to_save()
@@ -1142,6 +1153,7 @@ def _get_multiple_objs_to_save():
     assert os.path.exists(fname)
     loaded_objects = torch.load(fname)
     Checkpoint.load_objects(to_save, loaded_objects)
+    os.remove(fname)
 
 
 def test_load_checkpoint_with_different_num_classes(dirname):
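One caveat visible in the diff: the new string path calls torch.load(checkpoint) with default arguments, and the kwargs of load_objects are forwarded to load_state_dict, not to torch.load. So device remapping (e.g. restoring a CUDA-saved checkpoint on a CPU-only host) still requires the pre-loading style. A minimal sketch, with an illustrative model and filepath:

import torch
import torch.nn as nn

from ignite.handlers import Checkpoint

model = nn.Linear(4, 2)  # illustrative
checkpoint_fp = "/tmp/myprefix_checkpoint_40.pth"  # illustrative path

# Pre-load with an explicit map_location, then pass the mapping as before;
# the str-accepting path offers no hook for torch.load options.
ckpt = torch.load(checkpoint_fp, map_location="cpu")
Checkpoint.load_objects(to_load={"model": model}, checkpoint=ckpt)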