Improved max_iters handling #1565
base: master
@@ -118,18 +117,18 @@ def compute_mean_std(engine, batch):

     """

-    _state_dict_all_req_keys = ("epoch_length", "max_epochs")
-    _state_dict_one_of_opt_keys = ("iteration", "epoch")
+    _state_dict_all_req_keys = ("epoch_length",)
But if we interrupted the engine during the first epoch, we would not have `epoch_length`.
-    _state_dict_all_req_keys = ("epoch_length", "max_epochs")
-    _state_dict_one_of_opt_keys = ("iteration", "epoch")
+    _state_dict_all_req_keys = ("epoch_length",)
+    _state_dict_one_of_opt_keys = (("iteration", "epoch",), ("max_epochs", "max_iters",))
We can relax the constraint that `max_epochs` and `max_iters` cannot both have values.
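For illustration, here is a minimal sketch of how grouped "one of" keys could be validated. It is not the code from this PR: `check_state_dict` and its `strict` flag are hypothetical names, and only the two key tuples are taken from the diff above. With `strict=False`, a group may contain more than one key, which is the relaxed behaviour suggested here.

```python
from typing import Mapping, Tuple

# Key tuples copied from the diff; everything else is an illustrative assumption.
REQUIRED_KEYS: Tuple[str, ...] = ("epoch_length",)
ONE_OF_KEY_GROUPS: Tuple[Tuple[str, ...], ...] = (
    ("iteration", "epoch"),
    ("max_epochs", "max_iters"),
)


def check_state_dict(state_dict: Mapping, strict: bool = True) -> None:
    """Hypothetical validator: every required key, and one key per group."""
    for key in REQUIRED_KEYS:
        if key not in state_dict:
            raise ValueError(f"Required key '{key}' is missing from state_dict")

    for group in ONE_OF_KEY_GROUPS:
        present = [key for key in group if key in state_dict]
        if not present:
            raise ValueError(f"state_dict should contain one of {group}")
        # In strict mode, two keys from the same group are rejected.
        if strict and len(present) > 1:
            raise ValueError(f"state_dict should contain only one of {group}")
```

Whether both keys of a group should be accepted is exactly the open question in this thread; the flag only makes the two options explicit.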
@@ -545,7 +541,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
             self.state.epoch = 0
             if self.state.epoch_length is not None:
                 self.state.epoch = self.state.iteration // self.state.epoch_length
-        elif "epoch" in state_dict:
+        else:
According to `state_dict`, this should never happen.
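A small standalone sketch of the branch under discussion, assuming the state dict has already been validated to contain either `iteration` or `epoch`. `resolve_epoch` is a hypothetical helper, not a method of the engine. It keeps an explicit error for the case that "should never happen" instead of relying on a bare `else`.

```python
from typing import Mapping, Optional


def resolve_epoch(state_dict: Mapping, epoch_length: Optional[int]) -> int:
    """Hypothetical helper: derive the epoch counter from a state dict."""
    if "iteration" in state_dict:
        if epoch_length is None:
            # Without epoch_length the epoch cannot be derived; start from 0.
            return 0
        return state_dict["iteration"] // epoch_length
    if "epoch" in state_dict:
        return state_dict["epoch"]
    # Unreachable if the state dict was validated beforehand.
    raise ValueError("state_dict must contain 'iteration' or 'epoch'")
```

For example, `resolve_epoch({"iteration": 25}, 10)` gives 2, since 25 // 10 == 2.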
         self.state.max_iters = max_iters

     def _check_and_set_epoch_length(self, data: Iterable, epoch_length: Optional[int] = None) -> None:
+        # Can't we accept a redefinition ?
I think we can't, because the relationship between epoch and iteration would then become invalid. By the way, if `epoch_length` is changed, we could instead reinitialize the state (resetting epoch and iteration to zero and showing a message to the user).
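To make the concern concrete: at iteration 25 with `epoch_length=10` the engine is in epoch 2 (25 // 10), but with `epoch_length=20` the same iteration would correspond to epoch 1. Below is a minimal sketch of the reset-and-warn alternative. `SketchState` and `set_epoch_length` are hypothetical stand-ins, not the engine's actual state class or method.

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchState:
    """Hypothetical stand-in for the engine state."""
    epoch: int = 0
    iteration: int = 0
    epoch_length: Optional[int] = None


def set_epoch_length(state: SketchState, epoch_length: int) -> None:
    """Set epoch_length; on redefinition, reset the counters and warn."""
    if epoch_length < 1:
        raise ValueError("epoch_length should be a positive integer")
    if state.epoch_length is not None and state.epoch_length != epoch_length:
        warnings.warn(
            f"epoch_length changed from {state.epoch_length} to {epoch_length}; "
            "resetting epoch and iteration to zero"
        )
        # The old counters no longer satisfy epoch == iteration // epoch_length.
        state.epoch = 0
        state.iteration = 0
    state.epoch_length = epoch_length
```

Resetting discards progress, so whether this is preferable to raising an error is a design choice left to the maintainers.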
@@ -873,3 +903,14 @@ def _run_once_on_dataset(self) -> float:
                 self._handle_exception(e)

         return time.time() - start_time
+
+    def debug(self, enabled: bool = True) -> None:
Do you still want to merge this method into master?
Fixes #1521
Description:
`Engine.debug()` method
Check list: