Fix TrainsSaver handling of Checkpoint's n_saved by jkhenning · Pull Request #1135 · pytorch/ignite · GitHub
Fix TrainsSaver handling of Checkpoint's n_saved #1135


Merged (9 commits) on Jun 23, 2020
2 changes: 1 addition & 1 deletion examples/contrib/mnist/mnist_with_trains_logger.py
@@ -122,7 +122,7 @@ def compute_metrics(engine):

handler = Checkpoint(
{"model": model},
TrainsSaver(trains_logger, dirname="~/.trains/cache/"),
TrainsSaver(trains_logger, dirname="./TRAINS"),
n_saved=1,
score_function=lambda e: e.state.metrics["accuracy"],
score_name="val_acc",
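The example now writes checkpoints to a local ./TRAINS directory instead of the trains cache folder. For context, a minimal runnable sketch of the pattern this example uses, with DiskSaver standing in for TrainsSaver so it runs without a trains server (the variable names and the dummy metric below are assumptions, not taken verbatim from the example):

import torch
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

model = torch.nn.Linear(2, 2)
evaluator = Engine(lambda engine, batch: batch)  # stand-in evaluation engine

handler = Checkpoint(
    {"model": model},
    DiskSaver(dirname="./TRAINS", require_empty=False),
    n_saved=1,
    score_function=lambda e: e.state.metrics["accuracy"],
    score_name="val_acc",
)

@evaluator.on(Events.COMPLETED)
def set_dummy_metric(engine):
    # Stand-in for a real validation metric; registered before the checkpoint
    # handler so the score is available when the handler runs.
    engine.state.metrics["accuracy"] = 0.9

evaluator.add_event_handler(Events.COMPLETED, handler)
evaluator.run([0], max_epochs=1)
# Writes a file such as ./TRAINS/model_val_acc=0.9000.pt (the exact name depends on the ignite version).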
138 changes: 106 additions & 32 deletions ignite/contrib/handlers/trains_logger.py
@@ -1,8 +1,10 @@
import os
import tempfile
import warnings
from collections import defaultdict
from datetime import datetime
from typing import Mapping, Optional
from enum import Enum
from typing import Any, List, Mapping, Optional, Type

import torch

@@ -502,7 +504,8 @@ class TrainsLogger(BaseLogger):

def __init__(self, *_, **kwargs):
try:
import trains
from trains import Task
from trains.binding.frameworks.tensorflow_bind import WeightsGradientHistHelper
except ImportError:
raise RuntimeError(
"This contrib module requires trains to be installed. "
@@ -528,18 +531,16 @@ def __setattr__(self, attr, val):

self._task = _Stub()
else:
self._task = trains.Task.init(
self._task = Task.init(
project_name=kwargs.get("project_name"),
task_name=kwargs.get("task_name"),
task_type=kwargs.get("task_type", trains.Task.TaskTypes.training),
task_type=kwargs.get("task_type", Task.TaskTypes.training),
**experiment_kwargs,
)

self.trains_logger = self._task.get_logger()

self.grad_helper = trains.binding.frameworks.tensorflow_bind.WeightsGradientHistHelper(
logger=self.trains_logger,
)
self.grad_helper = WeightsGradientHistHelper(logger=self.trains_logger,)

@classmethod
def set_bypass_mode(cls, bypass: bool) -> None:
@@ -634,6 +635,8 @@ def __init__(self, logger: TrainsLogger = None, output_uri: str = None, dirname:
if "atomic" not in kwargs:
kwargs["atomic"] = False

self._checkpoint_slots = defaultdict(list)

super(TrainsSaver, self).__init__(dirname=dirname, *args, **kwargs)

@idist.one_rank_only()
@@ -659,32 +662,92 @@ def _setup_check_trains(self, logger, output_uri):
if output_uri:
self._task.output_uri = output_uri

class _CallbacksContext:
def __init__(
self,
callback_type: Type[Enum],
slots: List,
checkpoint_key: str,
filename: str,
basename: str,
metadata: Optional[Mapping] = None,
):
self._callback_type = callback_type
self._slots = slots
self._checkpoint_key = str(checkpoint_key)
self._filename = filename
self._basename = basename
self._metadata = metadata

def pre_callback(self, action: str, model_info: Any):
if action != self._callback_type.save:
return model_info

try:
slot = self._slots.index(None)
self._slots[slot] = model_info.upload_filename
except ValueError:
self._slots.append(model_info.upload_filename)
slot = len(self._slots) - 1

model_info.upload_filename = "{}_{}{}".format(self._basename, slot, os.path.splitext(self._filename)[1])
model_info.local_model_id = "{}:{}".format(self._checkpoint_key, model_info.upload_filename)
return model_info

def post_callback(self, action: str, model_info: Any):
if action != self._callback_type.save:
return model_info

model_info.model.name = "{}: {}".format(model_info.task.name, self._filename)
prefix = "Checkpoint Metadata: "
metadata = "{}{}".format(
prefix,
", ".join("{}={}".format(k, v) for k, v in self._metadata.items()) if self._metadata else "none",
)
comment = "\n".join(
metadata if line.startswith(prefix) else line for line in (model_info.model.comment or "").split("\n")
)
if prefix not in comment:
comment += "\n" + metadata
model_info.model.comment = comment

return model_info

def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
super(TrainsSaver, self).__call__(checkpoint, filename, metadata)

if idist.get_rank() == 0:
# Maybe wont work with XLA
if self._atomic:
try:
import trains
except ImportError:
raise RuntimeError(
"This contrib module requires trains to be installed. "
"You may install trains using: \n pip install trains \n"
)

# If atomic, DiskSaver's implementation first stores checkpoint into a temporary file
# and prohibits trains to automatically detect correct artifact path and name
path = os.path.join(self.dirname, filename)
if os.path.exists(path):
trains.binding.frameworks.WeightsFileHandler.create_output_model(
model=checkpoint,
saved_path=path,
framework=trains.model.Framework.pytorch,
task=self._task,
singlefile=True,
model_name=os.path.basename(filename),
)
try:
from trains import Model
from trains.binding.frameworks import WeightsFileHandler
except ImportError:
raise RuntimeError(
"This contrib module requires trains to be installed. "
"You may install trains using: \n pip install trains \n"
)

try:
basename = metadata["basename"]
except (TypeError, KeyError):
warnings.warn("Checkpoint metadata missing or basename cannot be found")
basename = "checkpoint"

checkpoint_key = (self.dirname, basename)

cb_context = self._CallbacksContext(
callback_type=WeightsFileHandler.CallbackType,
slots=self._checkpoint_slots[checkpoint_key],
checkpoint_key=str(checkpoint_key),
filename=filename,
basename=basename,
metadata=metadata,
)

pre_cb_id = WeightsFileHandler.add_pre_callback(cb_context.pre_callback)
post_cb_id = WeightsFileHandler.add_post_callback(cb_context.post_callback)

try:
super(TrainsSaver, self).__call__(checkpoint, filename, metadata)
finally:
WeightsFileHandler.remove_pre_callback(pre_cb_id)
WeightsFileHandler.remove_post_callback(post_cb_id)

@idist.one_rank_only()
def get_local_copy(self, filename: str) -> Optional[str]:
@@ -704,3 +767,14 @@ def get_local_copy(self, filename: str) -> Optional[str]:
if artifact:
return artifact.get_local_copy()
self._task.get_logger().report_text("Can not find artifact {}".format(filename))

@idist.one_rank_only()
def remove(self, filename: str) -> None:
super(TrainsSaver, self).remove(filename)
for slots in self._checkpoint_slots.values():
try:
slots[slots.index(filename)] = None
except ValueError:
pass
else:
break
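
In summary, each (dirname, basename) pair owns a list of slots: the pre-callback assigns every newly saved checkpoint to the first free slot and renames the upload to a fixed "<basename>_<slot><ext>" name, so trains keeps overwriting the same n_saved output models instead of registering a new model per save, and remove() frees a slot again when Checkpoint rotates the corresponding file out. A standalone, pure-Python sketch of that bookkeeping (illustration only, not part of the diff; the paths and filenames are made up):

import os
from collections import defaultdict

# With n_saved=2, four rotated checkpoint files map onto just two fixed upload names.
slots = defaultdict(list)
key = ("/tmp/checkpoints", "best_model")  # (dirname, basename), assumed values

def upload_name(filename):
    lst = slots[key]
    try:
        slot = lst.index(None)   # reuse a freed slot if one exists
        lst[slot] = filename
    except ValueError:
        lst.append(filename)     # otherwise open a new slot
        slot = len(lst) - 1
    return "best_model_{}{}".format(slot, os.path.splitext(filename)[1])

def remove(filename):
    lst = slots[key]
    if filename in lst:
        lst[lst.index(filename)] = None  # free the slot but keep its position

print(upload_name("best_model_1_val_acc=0.1.pt"))  # -> best_model_0.pt
print(upload_name("best_model_2_val_acc=0.2.pt"))  # -> best_model_1.pt
remove("best_model_1_val_acc=0.1.pt")              # Checkpoint drops the worse file
print(upload_name("best_model_3_val_acc=0.3.pt"))  # -> best_model_0.pt (slot reused)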
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -17,7 +17,7 @@ mlflow
neptune-client
tensorboard
pynvml; python_version > '3.5'
trains
trains>=0.15.1
# Examples dependencies
pandas
gym
85 changes: 85 additions & 0 deletions tests/ignite/contrib/handlers/test_trains_logger.py
@@ -1,10 +1,13 @@
import math
import os
from collections import defaultdict
from unittest.mock import ANY, MagicMock, Mock, call

import pytest
import torch
import trains
from trains.binding.frameworks import WeightsFileHandler
from trains.model import Framework

import ignite.distributed as idist
from ignite.contrib.handlers.trains_logger import *
@@ -670,6 +673,88 @@ def test_trains_disk_saver_integration_no_logger():
assert saved_files[0] == "model_1.pt"


def test_trains_saver_callbacks():
mock_task = MagicMock(spec=trains.Task)
mock_task.name = "check-task"

mock_model = MagicMock(spec=trains.OutputModel)

model_info = WeightsFileHandler.ModelInfo(
model=mock_model,
upload_filename="test.pt",
local_model_path="",
local_model_id="",
framework=Framework.pytorch,
task=mock_task,
)

mock_model_info = MagicMock(spec_set=model_info)

# Simulate 4 calls to save model and 2 to remove (n_saved=2)
filenames = [
"best_model_5_val_acc=0.123.pt",
"best_model_6_val_acc=0.234.pt",
"best_model_7_val_acc=0.356.pt",
"best_model_8_val_acc=0.456.pt",
]
metadata_list = [
{"basename": "best_model", "score_name": "val_acc", "priority": 0.123},
{"basename": "best_model", "score_name": "val_acc", "priority": 0.234},
{"basename": "best_model", "score_name": "val_acc", "priority": 0.345},
{"basename": "best_model", "score_name": "val_acc", "priority": 0.456},
]
dirname = "/tmp/test"

_checkpoint_slots = defaultdict(list)

n_saved = 2

for i, (filename, metadata) in enumerate(zip(filenames, metadata_list)):

mock_model_info.upload_filename = filename

if i >= n_saved:
# Remove
filename_to_remove = filenames[i % n_saved]
for slots in _checkpoint_slots.values():
try:
slots[slots.index(filename_to_remove)] = None
except ValueError:
pass
else:
i = i % n_saved
break

basename = metadata["basename"]
checkpoint_key = (dirname, basename)

context = TrainsSaver._CallbacksContext(
callback_type=WeightsFileHandler.CallbackType,
slots=_checkpoint_slots[checkpoint_key],
checkpoint_key=str(checkpoint_key),
filename=filename,
basename=basename,
metadata=metadata,
)

output_model_info = context.pre_callback(str(WeightsFileHandler.CallbackType.save), mock_model_info)
assert (
hasattr(output_model_info, "upload_filename")
and "{}_{}.pt".format(basename, i) in output_model_info.upload_filename
)
assert hasattr(output_model_info, "local_model_id") and str(checkpoint_key) in output_model_info.local_model_id

output_model_info = context.post_callback(str(WeightsFileHandler.CallbackType.save), mock_model_info)
assert hasattr(output_model_info, "model") and hasattr(output_model_info.model, "name")
assert hasattr(output_model_info, "model") and hasattr(output_model_info.model, "comment")
assert isinstance(output_model_info.model.name, str) and filename in output_model_info.model.name
assert (
isinstance(output_model_info.model.comment, str)
and metadata["basename"] in output_model_info.model.comment
and metadata["score_name"] in output_model_info.model.comment
)


class DummyModel(torch.nn.Module):
def __init__(self):
super(DummyModel, self).__init__()