Added input data type check by vfdev-5 · Pull Request #1301 · pytorch/ignite · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Added input data type check #1301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit on
Sep 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,9 @@ def switch_batch(engine):
"Please, use torch.manual_seed or ignite.utils.manual_seed"
)

if not isinstance(data, Iterable):
raise TypeError("Argument data should be iterable")

if self.state.max_epochs is not None:
# Check and apply overridden parameters
if max_epochs is not None:
Expand Down
18 changes: 14 additions & 4 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,29 @@ def test_terminate():


def test_invalid_process_raises_with_invalid_signature():
    """Engine construction must fail with ValueError for a missing or mis-shaped process function.

    Each case pins the error message as well as the exception type, so an
    unrelated ValueError raised during construction cannot make the test pass.
    """
    # No process function at all.
    with pytest.raises(ValueError, match=r"Engine must be given a processing function in order to run"):
        Engine(None)

    # Process functions with the wrong number of parameters are rejected by the
    # signature check; a two-parameter (engine, batch) callable is the accepted form.
    with pytest.raises(ValueError, match=r"Error adding .+ takes parameters .+ but will be called with"):
        Engine(lambda: None)

    with pytest.raises(ValueError, match=r"Error adding .+ takes parameters .+ but will be called with"):
        Engine(lambda batch: None)

    with pytest.raises(ValueError, match=r"Error adding .+ takes parameters .+ but will be called with"):
        Engine(lambda engine, batch, extra_arg: None)


def test_invalid_input_data():
    """``Engine.run`` must raise TypeError when ``data`` is not iterable."""

    # A plain function object is not an Iterable, so it exercises the input type check.
    def _not_iterable():
        pass

    trainer = Engine(lambda e, b: None)

    with pytest.raises(TypeError, match=r"Argument data should be iterable"):
        trainer.run(_not_iterable)


def test_current_epoch_counter_increases_every_epoch():
engine = Engine(MagicMock(return_value=1))
max_epochs = 5
Expand Down
0