Added support of Horovod by vfdev-5 · Pull Request #1195 · pytorch/ignite · GitHub

Added support of Horovod #1195


Merged: 30 commits merged on Aug 3, 2020

Commits
035ed48
[WIP] Horovod comp model
vfdev-5 Jul 3, 2020
8977bb9
[WIP] Horovod comp model
vfdev-5 Jul 7, 2020
40591dd
Refactored test_utils.py into 3 files
vfdev-5 Jul 7, 2020
a9232d7
Merge branch 'refactor-idist-test-utils' into idist-hvd
vfdev-5 Jul 7, 2020
a865b4e
Merge branch 'master' of github.com:pytorch/ignite into idist-hvd
vfdev-5 Jul 8, 2020
f51399a
[WIP] Run horovod tests
vfdev-5 Jul 9, 2020
7c22c58
[WIP] Horovod comp model + tests
vfdev-5 Jul 11, 2020
7ecaf41
autopep8 fix
Jul 11, 2020
d979b42
Merge branch 'master' into idist-hvd
vfdev-5 Jul 11, 2020
5c8fd9a
[WIP] More tests
vfdev-5 Jul 11, 2020
ede1693
Updated utils tests
vfdev-5 Jul 11, 2020
6e11541
autopep8 fix
Jul 11, 2020
66a7f8b
[WIP] more tests
vfdev-5 Jul 12, 2020
0b075a6
Updated tests and code and cifar10 example
vfdev-5 Jul 22, 2020
a902a3c
autopep8 fix
Jul 22, 2020
9dca852
Merge branch 'master' of github.com:pytorch/ignite into idist-hvd
vfdev-5 Jul 22, 2020
0b4c39b
Fixed failing CI and updated code
vfdev-5 Jul 22, 2020
4d688c9
autopep8 fix
Jul 22, 2020
48e2b09
Fixes failing test
vfdev-5 Jul 22, 2020
81fc117
Merge branch 'idist-hvd' of github.com:pytorch/ignite into idist-hvd
vfdev-5 Jul 22, 2020
df0890e
Fixed bug with new/old hvd API and the config
vfdev-5 Jul 23, 2020
92ec040
Added metric tests
vfdev-5 Jul 23, 2020
4c59648
Formatting and docs updated
vfdev-5 Jul 23, 2020
5c16a8c
Updated frequency test
vfdev-5 Aug 2, 2020
4dcb230
Fixed formatting and a typo in idist.model_name docs
vfdev-5 Aug 2, 2020
862e752
Fixed failing test
vfdev-5 Aug 2, 2020
75d02e3
Merge branch 'master' of github.com:pytorch/ignite into idist-hvd
vfdev-5 Aug 2, 2020
598cbc0
Docs updates and updated auto methods according to horovod API
vfdev-5 Aug 2, 2020
840950e
autopep8 fix
Aug 2, 2020
5752981
Cosmetics
vfdev-5 Aug 2, 2020
86 changes: 85 additions & 1 deletion .circleci/config.yml
@@ -4,7 +4,11 @@ parameters:
pytorch_stable_image:
type: string
# https://hub.docker.com/r/pytorch/pytorch/tags
default: "pytorch/pytorch:1.5-cuda10.1-cudnn7-runtime"
default: "pytorch/pytorch:1.5.1-cuda10.1-cudnn7-runtime"
pytorch_stable_image_devel:
type: string
# https://hub.docker.com/r/pytorch/pytorch/tags
default: "pytorch/pytorch:1.5.1-cuda10.1-cudnn7-devel"
workingdir:
type: string
default: "/tmp/ignite"
@@ -40,6 +44,12 @@ pull_pytorch_stable_image: &pull_pytorch_stable_image
command: |
docker pull << pipeline.parameters.pytorch_stable_image >>

pull_pytorch_stable_devel_image: &pull_pytorch_stable_devel_image
- run:
name: Pull PyTorch Stable Develop Image
command: |
docker pull << pipeline.parameters.pytorch_stable_image_devel >>


run_pytorch_container: &run_pytorch_container
- run:
@@ -51,6 +61,17 @@ run_pytorch_container: &run_pytorch_container
docker exec -it pthd nvidia-smi
docker exec -it pthd ls


run_pytorch_devel_container: &run_pytorch_devel_container
- run:
name: Start Pytorch dev container
environment:
wd: << pipeline.parameters.workingdir >>
command: |
docker run --gpus=all --rm -itd --shm-size 16G -v ${wd}:/ignite -w /ignite --name pthd << pipeline.parameters.pytorch_stable_image_devel >>
docker exec -it pthd nvidia-smi
docker exec -it pthd ls

install_dependencies: &install_dependencies
- run:
name: Install dependencies
@@ -194,6 +215,68 @@ jobs:
docker exec -it pthd /bin/bash -c "${test_cmd} --num_epochs=7 ${resume_opt}"


two_gpus_hvd_tests:
<<: *two_gpus

working_directory: << pipeline.parameters.workingdir >>

steps:
- checkout
- <<: *pull_pytorch_stable_devel_image
- <<: *run_pytorch_devel_container
- <<: *install_dependencies
- run:
name: "Install Horovod with NCCL GPU ops"
command: |

# Following https://github.com/horovod/horovod/blob/master/Dockerfile.test.gpu
# and https://github.com/horovod/horovod/issues/1944#issuecomment-628192778
docker exec -it pthd /bin/bash -c "apt-get update && apt-get install -y git"
docker exec -it pthd /bin/bash -c "git clone --recursive https://github.com/horovod/horovod.git /horovod && cd /horovod && python setup.py sdist"
docker exec -it pthd /bin/bash -c "conda install -y cmake=3.16 nccl=2.5 -c conda-forge"
docker exec -it pthd /bin/bash -c 'cd /horovod && HOROVOD_GPU_OPERATIONS=NCCL HOROVOD_NCCL_LINK=SHARED HOROVOD_WITHOUT_MPI=1 HOROVOD_WITH_PYTORCH=1 pip install -v $(ls /horovod/dist/horovod-*.tar.gz) && ldconfig'
docker exec -it pthd horovodrun --check-build

- run:
name: Run 1 Node 2 GPUs Unit Tests
command: |
export test_cmd='sh tests/run_gpu_tests.sh'
docker exec -it pthd /bin/bash -c "${test_cmd}"
# no CUDA devices Horovod tests
export test_cmd='CUDA_VISIBLE_DEVICES="" py.test --cov ignite --cov-append --cov-report term-missing --cov-report xml -vvv tests/ -m distributed'
docker exec -it pthd /bin/bash -c "${test_cmd}"

- run:
name: Codecov upload
command: |
bash <(curl -s https://codecov.io/bash) -Z -F gpu-2-hvd

- run:
name: "Check CIFAR10 using horovodrun"
command: |
docker exec -it pthd pip install fire
export example_path="examples/contrib/cifar10"
# initial run
export stop_cmd="--stop_iteration=500"
export test_cmd="cd ${example_path} && CI=1 horovodrun -np 2 python -u main.py run --backend=horovod"
docker exec -it pthd /bin/bash -c "${test_cmd} ${stop_cmd}"
# resume
export resume_opt="--resume-from=/tmp/output-cifar10/resnet18_backend-horovod-2_stop-on-500/training_checkpoint_400.pt"
docker exec -it pthd /bin/bash -c "${test_cmd} --num_epochs=7 ${resume_opt}"

- run:
name: "Check CIFAR10 using spawn"
command: |
export example_path="examples/contrib/cifar10"
# initial run
export stop_cmd="--stop_iteration=500"
export test_cmd="cd ${example_path} && CI=1 python -u main.py run --backend=horovod --nproc_per_node=2"
docker exec -it pthd /bin/bash -c "${test_cmd} ${stop_cmd}"
# resume
export resume_opt="--resume-from=/tmp/output-cifar10/resnet18_backend-horovod-2_stop-on-500/training_checkpoint_400.pt"
docker exec -it pthd /bin/bash -c "${test_cmd} --num_epochs=7 ${resume_opt}"


# -------------------------------------------------------------------------------------
# Workflows
# -------------------------------------------------------------------------------------
@@ -204,3 +287,4 @@ workflows:
- one_gpu_tests
- two_gpus_tests
- two_gpus_check_dist_cifar10_example
- two_gpus_hvd_tests
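The `-m distributed` selection in the job above relies on a pytest marker used across the test suite. A minimal sketch of a Horovod-gated test in that style is shown below; the skip condition uses the `has_hvd_support` flag introduced further down in this PR, while the test name and body are illustrative rather than taken from the PR itself:

```python
import pytest

import ignite.distributed as idist
from ignite.distributed.comp_models.horovod import has_hvd_support


@pytest.mark.distributed
@pytest.mark.skipif(not has_hvd_support, reason="Skip if no Horovod dist support")
def test_idist_horovod_backend():
    # single-process sanity check: initialize the backend, query it, tear down
    idist.initialize("horovod")
    assert idist.backend() == "horovod"
    assert idist.get_world_size() >= 1
    idist.finalize()
```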
14 changes: 12 additions & 2 deletions examples/contrib/cifar10/README.md
@@ -54,9 +54,19 @@ or
python -u main.py run --backend="nccl" --nproc_per_node=2
```

If the user would like to provide an already downloaded dataset, the path can be set up in the parameters as
##### Using [Horovod](https://horovod.readthedocs.io/en/latest/index.html) as distributed backend

Please make sure to have Horovod installed before running.

Let's start training on a single node with 2 gpus:
```bash
--data_path="/path/to/cifar10/"
# horovodrun
horovodrun -np=2 python -u main.py run --backend="horovod"
```
or
```bash
# using function spawn inside the code
python -u main.py run --backend="horovod" --nproc_per_node=2
```


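To make the two launch modes above concrete: the spawn variant (`--nproc_per_node=2`) drives the same code path as `horovodrun`, just with ignite spawning the worker processes itself. A minimal sketch, assuming the `idist.Parallel` context manager available in ignite 0.4 (the training function and config here are placeholders, not the CIFAR10 example code):

```python
import ignite.distributed as idist


def training(local_rank, config):
    # executed once per worker; idist reports per-process info for the horovod backend
    print("rank {} / {} on device {}".format(idist.get_rank(), idist.get_world_size(), idist.device()))


if __name__ == "__main__":
    config = {"num_epochs": 7}
    # nproc_per_node=2 makes idist spawn two workers itself;
    # when launched via `horovodrun -np 2 python script.py`, drop nproc_per_node.
    with idist.Parallel(backend="horovod", nproc_per_node=2) as parallel:
        parallel.run(training, config)
```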
32 changes: 26 additions & 6 deletions ignite/distributed/auto.py
@@ -8,6 +8,7 @@
from torch.utils.data.sampler import Sampler

from ignite.distributed import utils as idist
from ignite.distributed.comp_models import horovod as idist_hvd
from ignite.distributed.comp_models import native as idist_native
from ignite.distributed.comp_models import xla as idist_xla
from ignite.utils import setup_logger
@@ -130,6 +131,7 @@ def auto_model(model: nn.Module) -> nn.Module:
- send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device.
- wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1.
- wrap the model to `torch DataParallel`_ if no distributed context found and more than one CUDA devices available.
- broadcast the initial variable states from rank 0 to all other processes if Horovod distributed framework is used.

Examples:

@@ -166,13 +168,19 @@ def auto_model(model: nn.Module) -> nn.Module:

# distributed data parallel model
if idist.get_world_size() > 1:
if idist.backend() == idist_native.NCCL:
bnd = idist.backend()
if idist.has_native_dist_support and bnd == idist_native.NCCL:
lrank = idist.get_local_rank()
logger.info("Apply torch DistributedDataParallel on model, device id: {}".format(lrank))
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[lrank,])
elif idist.backend() == idist_native.GLOO:
elif idist.has_native_dist_support and bnd == idist_native.GLOO:
logger.info("Apply torch DistributedDataParallel on model")
model = torch.nn.parallel.DistributedDataParallel(model)
elif idist.has_hvd_support and bnd == idist_hvd.HOROVOD:
import horovod.torch as hvd

logger.info("Broadcast the initial variable states from rank 0 to all other processes")
hvd.broadcast_parameters(model.state_dict(), root_rank=0)

# not distributed but multiple GPUs reachable so data parallel model
elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type:
@@ -187,14 +195,18 @@ def auto_optim(optimizer: Optimizer) -> Optimizer:
all available backends from :meth:`~ignite.distributed.utils.available_backends()`).

Internally, this method is no-op for non-distributed and torch native distributed configuration.

For XLA distributed configuration, we create a new class that inherits from provided optimizer.
The goal is to override the `step()` method with specific `xm.optimizer_step`_ implementation.

For Horovod distributed configuration, optimizer is wrapped with Horovod Distributed Optimizer and
its state is broadcasted from rank 0 to all other processes.

Examples:

.. code-block:: python

import ignite.distribted as idist
import ignite.distributed as idist

optimizer = idist.auto_optim(optimizer)

@@ -208,11 +220,19 @@ def auto_optim(optimizer: Optimizer) -> Optimizer:
.. _xm.optimizer_step: http://pytorch.org/xla/release/1.5/index.html#torch_xla.core.xla_model.optimizer_step

"""
if not (idist.has_xla_support and idist.backend() == idist_xla.XLA_TPU):
bnd = idist.backend()
if idist.has_xla_support and bnd == idist_xla.XLA_TPU:
cls = type(optimizer.__class__.__name__, (optimizer.__class__,), dict(_XLADistributedOptimizer.__dict__))
return cls(optimizer)

if idist.has_hvd_support and bnd == idist_hvd.HOROVOD:
import horovod.torch as hvd

optimizer = hvd.DistributedOptimizer(optimizer)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
return optimizer

cls = type(optimizer.__class__.__name__, (optimizer.__class__,), dict(_XLADistributedOptimizer.__dict__))
return cls(optimizer)
return optimizer


class DistributedProxySampler(DistributedSampler):
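With the changes above, the Horovod-specific steps (broadcasting parameters from rank 0, wrapping the optimizer in `hvd.DistributedOptimizer`, and broadcasting its state) are handled inside the generic helpers. A minimal usage sketch, assuming Horovod is installed and the script is launched with something like `horovodrun -np 2 python script.py` (the model and optimizer are placeholders):

```python
import torch
import torch.nn as nn

import ignite.distributed as idist

idist.initialize("horovod")

model = nn.Linear(10, 2)                 # placeholder model
model = idist.auto_model(model)          # moved to idist.device(); state broadcast from rank 0

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer = idist.auto_optim(optimizer)  # wrapped with hvd.DistributedOptimizer, state broadcast

idist.finalize()
```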
5 changes: 5 additions & 0 deletions ignite/distributed/comp_models/__init__.py
@@ -1,4 +1,5 @@
from ignite.distributed.comp_models.base import _SerialModel
from ignite.distributed.comp_models.horovod import has_hvd_support
from ignite.distributed.comp_models.native import has_native_dist_support
from ignite.distributed.comp_models.xla import has_xla_support

@@ -15,6 +16,10 @@ def setup_available_computation_models():
from ignite.distributed.comp_models.xla import _XlaDistModel

models.append(_XlaDistModel)
if has_hvd_support:
from ignite.distributed.comp_models.horovod import _HorovodDistModel

models.append(_HorovodDistModel)

return tuple(models)

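Because `_HorovodDistModel` is registered only when `has_hvd_support` is true, the new backend surfaces through the public helpers only when `horovod.torch` can be imported. A quick check might look like this (a sketch using the existing `idist.available_backends()` helper; the printed strings are illustrative):

```python
import ignite.distributed as idist

# "horovod" appears here only if horovod.torch is importable
print(idist.available_backends())

if "horovod" in idist.available_backends():
    idist.initialize("horovod")
    print("backend:", idist.backend(), "world size:", idist.get_world_size())
    idist.finalize()
```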
2 changes: 1 addition & 1 deletion ignite/distributed/comp_models/base.py
@@ -177,7 +177,7 @@ class _SerialModel(ComputationModel):
"""

name = "serial"
available_backends = tuple()
available_backends = ()

def get_local_rank(self) -> int:
return 0