๊ธฐ์ฌ์ ๋ฒ์ญ์ ๊ณผ์ ์์ ์ ๋ ์ ์ค๋น๋์์ต๋๋ค.
์ฌ๋ฌ ๊ณ ์ฑ๋ฅ ์ปดํจํ ์ธ์คํด์ค์ ๋ํ ๋ถ์ฐ ํ๋ จ์ ๋๋์ ๋ฐ์ดํฐ์ ๋ํ ํ๋ ์ฌ์ธต ์ ๊ฒฝ๋ง์ ํ๋ จ ์๊ฐ์ ๋ช ์ฃผ์์ ๋ช ์๊ฐ, ์ฌ์ง์ด ๋ช ๋ถ์ผ๋ก ๋จ์ถํ ์ ์์ผ๋ฏ๋ก ์ด ํ๋ จ ๊ธฐ์ ์ ๋ฅ ๋ฌ๋์ ์ค์ ์์ฉ ๋ถ์ผ์์ ๋๋ฆฌ ๋ณด๊ธ๋ฉ๋๋ค. ์ฌ์ฉ์๋ ์ฌ๋ฌ ์ธ์คํด์ค์์ ๋ฐ์ดํฐ๋ฅผ ๊ณต์ ํ๊ณ ๋๊ธฐํํ๋ ๋ฐฉ๋ฒ์ ์ดํดํด์ผ ํ๋ฉฐ, ์ด๋ ๊ฒฐ๊ณผ์ ์ผ๋ก ํ์ฅ ํจ์จ์ฑ์ ํฐ ์ํฅ์ ๋ฏธ์นฉ๋๋ค. ๋ํ ์ฌ์ฉ์๋ ๋จ์ผ ์ธ์คํด์ค์์ ์คํ๋๋ ๊ต์ก ์คํฌ๋ฆฝํธ๋ฅผ ์ฌ๋ฌ ์ธ์คํด์ค์ ๋ฐฐํฌํ๋ ๋ฐฉ๋ฒ๋ ์์์ผ ํฉ๋๋ค.
์ด ๊ธฐ์ฌ์์๋ ๊ฐ๋ฐฉํ ๋ฅ ๋ฌ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ธ Apache MXNet๊ณผ Horovod ๋ถ์ฐ ํ์ต ํ๋ ์์ํฌ๋ฅผ ์ฌ์ฉํ์ฌ ํ์ต์ ๋ฐฐํฌํ๋ ๋น ๋ฅด๊ณ ์ฌ์ด ๋ฐฉ๋ฒ์ ๋ํด ์ค๋ช
ํฉ๋๋ค. Horovod ํ๋ ์์ํฌ์ ์ฑ๋ฅ ์ด์ ์ ๋ช
ํํ๊ฒ ๋ณด์ฌ์ฃผ๊ณ Horovod์ ๋ถ์ฐ ๋ฐฉ์์ผ๋ก ์๋ํ๋๋ก MXNet ๊ต์ก ์คํฌ๋ฆฝํธ๋ฅผ ์์ฑํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค.
Apache MXNet์ด๋ ๋ฌด์์ ๋๊น?
๋งค๊ฐ๋ณ์ ์๋ฒ๋ฅผ ์ฌ์ฉํ๋ MXNet์ ๋ถ์ฐ ๊ต์ก
ํธ๋ก๋ณด๋๋ ๋ฌด์์ธ๊ฐ์?
MXNet ๋ฐ Horovod ํตํฉ
MXNet์ Horovod์ ์ ์๋ ๋ถ์ฐ ํ์ต API๋ฅผ ํตํด Horovod์ ํตํฉ๋ฉ๋๋ค. Horovod ํต์ API horovod.broadcast(), horovod.allgather() ะธ horovod.allreduce() ์์ ๊ทธ๋ํ์ ์ผ๋ถ๋ก MXNet ์์ง์ ๋น๋๊ธฐ ์ฝ๋ฐฑ์ ์ฌ์ฉํ์ฌ ๊ตฌํ๋ฉ๋๋ค. ์ด๋ฌํ ๋ฐฉ์์ผ๋ก MXNet ์์ง์ ํต์ ๊ณผ ๊ณ์ฐ ๊ฐ์ ๋ฐ์ดํฐ ์ข ์์ฑ์ ์ฝ๊ฒ ์ฒ๋ฆฌํ์ฌ ๋๊ธฐํ๋ก ์ธํ ์ฑ๋ฅ ์์ค์ ๋ฐฉ์งํฉ๋๋ค. Horovod์ ์ ์๋ ๋ถ์ฐ ์ต์ ํ ๊ฐ์ฒด horovod.DistributedOptimizer ํฝ์ฐฝํ๋ค ์ต์ ํ ๋ถ์ฐ ๋งค๊ฐ๋ณ์ ์ ๋ฐ์ดํธ๋ฅผ ์ํด ํด๋น Horovod API๋ฅผ ํธ์ถํ๋๋ก MXNet์์. ์ด๋ฌํ ๋ชจ๋ ๊ตฌํ ์ธ๋ถ ์ฌํญ์ ์ต์ข ์ฌ์ฉ์์๊ฒ ํฌ๋ช ํ๊ฒ ๊ณต๊ฐ๋ฉ๋๋ค.
๋น ๋ฅธ ์์
MacBook์์ MXNet ๋ฐ Horovod๋ฅผ ์ฌ์ฉํ์ฌ MNIST ๋ฐ์ดํฐ์ธํธ์์ ์๊ท๋ชจ ์ปจ๋ฒ๋ฃจ์
์ ๊ฒฝ๋ง ํ๋ จ์ ๋น ๋ฅด๊ฒ ์์ํ ์ ์์ต๋๋ค.
๋จผ์ PyPI์์ mxnet ๋ฐ horovod๋ฅผ ์ค์นํฉ๋๋ค.
pip install mxnet
pip install horovod
์ฐธ๊ณ : ๋์ค์ ์ค๋ฅ๊ฐ ๋ฐ์ํ๋ ๊ฒฝ์ฐ pip ์ค์น horovod์ด์ฉ๋ฉด ๋ณ์๋ฅผ ์ถ๊ฐํด์ผ ํ ์๋ ์์ต๋๋ค MACOSX_DEPLOYMENT_TARGET=10.vv์ด๋์์ vv โ ์ด๋ MacOS ๋ฒ์ ์ ๋ฒ์ ์ ๋๋ค. ์๋ฅผ ๋ค์ด MacOSX Sierra์ ๊ฒฝ์ฐ ๋ค์๊ณผ ๊ฐ์ด ์์ฑํด์ผ ํฉ๋๋ค. MACOSX_DEPLOYMENT_TARGET=10.12 pip ์ค์น Horovod
๊ทธ๋ฐ ๋ค์ OpenMPI๋ฅผ ์ค์นํ์ญ์์ค.
๋ง์ง๋ง์ผ๋ก ํ
์คํธ ์คํฌ๋ฆฝํธ๋ฅผ ๋ค์ด๋ก๋ํ์ญ์์ค. mxnet_mnist.py
mpirun -np 2 -H localhost:2 -bind-to none -map-by slot python mxnet_mnist.py
๊ทธ๋ฌ๋ฉด ํ๋ก์ธ์์ ๋ ์ฝ์ด์ ๋ํ ๊ต์ก์ด ์คํ๋ฉ๋๋ค. ์ถ๋ ฅ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
INFO:root:Epoch[0] Batch [0-50] Speed: 2248.71 samples/sec accuracy=0.583640
INFO:root:Epoch[0] Batch [50-100] Speed: 2273.89 samples/sec accuracy=0.882812
INFO:root:Epoch[0] Batch [50-100] Speed: 2273.39 samples/sec accuracy=0.870000
์ฑ๋ฅ ๋ฐ๋ชจ
50๊ฐ์ ์ธ์คํด์ค๊ฐ ์๋ 1๊ฐ GPU์ ImageNet ๋ฐ์ดํฐ์ธํธ์์ ResNet64-vXNUMX ๋ชจ๋ธ์ ํ๋ จํ๋ ๊ฒฝ์ฐ p3.16xlarge ๊ฐ๊ฐ AWS ํด๋ผ์ฐ๋์ 2๊ฐ์ NVIDIA Tesla V8 GPU๊ฐ ํฌํจ๋ EC100์์ ์ด๋น 45000๊ฐ ์ด๋ฏธ์ง(์ฆ, ์ด๋น ํ๋ จ๋ ์ํ ์)์ ํ๋ จ ์ฒ๋ฆฌ๋์ ๋ฌ์ฑํ์ต๋๋ค. 44๋ฒ์ ์ํฌํฌ ์ดํ 90๋ถ ๋ง์ ํ๋ จ์ด ์๋ฃ๋์์ผ๋ฉฐ ์ต๊ณ ์ ํ๋๋ 75.7%์ ๋๋ค.
์ฐ๋ฆฌ๋ ์ด๋ฅผ ๋จ์ผ ๋งค๊ฐ๋ณ์ ์๋ฒ์ ์๋ฒ ๋ ์์ ์ ๋น์จ์ด ๊ฐ๊ฐ 8:16 ๋ฐ 32:64์ธ 1, 1, 2 ๋ฐ 1 GPU์์ ๋งค๊ฐ๋ณ์ ์๋ฒ๋ฅผ ์ฌ์ฉํ๋ MXNet์ ๋ถ์ฐ ๊ต์ก ์ ๊ทผ ๋ฐฉ์๊ณผ ๋น๊ตํ์ต๋๋ค. ์๋ ๊ทธ๋ฆผ 1์์ ๊ฒฐ๊ณผ๋ฅผ ๋ณผ ์ ์์ต๋๋ค. ์ผ์ชฝ์ y์ถ์์ ๋ง๋๋ ์ด๋น ํ๋ จํ ์ด๋ฏธ์ง ์๋ฅผ ๋ํ๋ด๊ณ , ์ ์ ์ค๋ฅธ์ชฝ์ y์ถ์์ ์ค์ผ์ผ๋ง ํจ์จ์ฑ(์ฆ, ์ค์ ์ฒ๋ฆฌ๋๊ณผ ์ด์์ ์ธ ์ฒ๋ฆฌ๋์ ๋น์จ)์ ๋ฐ์ํฉ๋๋ค. ๋ณด์๋ค์ํผ ์๋ฒ ์ ์ ํ์ ํ์ฅ ํจ์จ์ฑ์ ์ํฅ์ ๋ฏธ์นฉ๋๋ค. ๋งค๊ฐ๋ณ์ ์๋ฒ๊ฐ ํ๋๋ง ์๋ ๊ฒฝ์ฐ GPU 38๊ฐ์์๋ ํ์ฅ ํจ์จ์ฑ์ด 64%๋ก ๋จ์ด์ง๋๋ค. Horovod์ ๋์ผํ ํ์ฅ ํจ์จ์ฑ์ ๋ฌ์ฑํ๋ ค๋ฉด ์์ ์ ์์ ๋นํด ์๋ฒ ์๋ฅผ ๋ ๋ฐฐ๋ก ๋๋ ค์ผ ํฉ๋๋ค.
๊ทธ๋ฆผ 1. Horovod ๋ฐ ๋งค๊ฐ๋ณ์ ์๋ฒ์ ํจ๊ป MXNet์ ์ฌ์ฉํ ๋ถ์ฐ ํ์ต ๋น๊ต
์๋ ํ 1์์๋ 64๊ฐ์ GPU์์ ์คํ์ ์คํํ ๋ ์ธ์คํด์ค๋น ์ต์ข ๋น์ฉ์ ๋น๊ตํฉ๋๋ค. Horovod์ ํจ๊ป MXNet์ ์ฌ์ฉํ๋ฉด ์ต์ ๋น์ฉ์ผ๋ก ์ต๊ณ ์ ์ฒ๋ฆฌ๋์ ์ ๊ณตํฉ๋๋ค.
ํ 1. ์๋ฒ ๋ ์์
์ ๋น์จ์ด 2:1์ธ Horovod์ Parameter Server ๊ฐ์ ๋น์ฉ ๋น๊ต.
์ฌํ ๋จ๊ณ
๋ค์ ๋จ๊ณ์์๋ MXNet ๋ฐ Horovod๋ฅผ ์ฌ์ฉํ์ฌ ๋ถ์ฐ ๊ต์ก ๊ฒฐ๊ณผ๋ฅผ ์ฌํํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ ๋๋ฆฌ๊ฒ ์ต๋๋ค. MXNet์ ํตํ ๋ถ์ฐ ํ์ต์ ๋ํด ์์ธํ ์์๋ณด๋ ค๋ฉด ์ฝ์ด๋ณด์ธ์.
1 ๋จ๊ณ
๋ถ์ฐ ํ์ต์ ์ฌ์ฉํ๋ ค๋ฉด MXNet ๋ฒ์ 1.4.0 ์ด์ ๋ฐ Horovod ๋ฒ์ 0.16.0 ์ด์์ผ๋ก ๋์ข
์ธ์คํด์ค ํด๋ฌ์คํฐ๋ฅผ ์์ฑํ์ธ์. GPU ํ๋ จ์ ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ ์ค์นํด์ผ ํฉ๋๋ค. ์ฐ๋ฆฌ์ ๊ฒฝ์ฐ GPU ๋๋ผ์ด๋ฒ 16.04, CUDA 396.44, cuDNN 9.2 ๋ผ์ด๋ธ๋ฌ๋ฆฌ, NCCL 7.2.1 ์ปค๋ฎค๋์ผ์ดํฐ ๋ฐ OpenMPI 2.2.13์ด ํฌํจ๋ Ubuntu 3.1.1 Linux๋ฅผ ์ ํํ์ต๋๋ค. ๋ํ ๋น์ ์ ์ฌ์ฉํ ์ ์์ต๋๋ค
2 ๋จ๊ณ
MXNet ๊ต์ก ์คํฌ๋ฆฝํธ์ Horovod API๋ฅผ ์ฌ์ฉํ๋ ๊ธฐ๋ฅ์ ์ถ๊ฐํ์ธ์. MXNet Gluon API๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ๋ ์๋ ์คํฌ๋ฆฝํธ๋ฅผ ๊ฐ๋จํ ํ ํ๋ฆฟ์ผ๋ก ์ฌ์ฉํ ์ ์์ต๋๋ค. ํด๋น ํ๋ จ ์คํฌ๋ฆฝํธ๊ฐ ์ด๋ฏธ ์๋ ๊ฒฝ์ฐ ๊ตต์ ๊ธ์จ์ ์ค์ด ํ์ํฉ๋๋ค. Horovod๋ฅผ ๋ฐฐ์ฐ๊ธฐ ์ํด ์ํํด์ผ ํ ๋ช ๊ฐ์ง ์ค์ํ ๋ณ๊ฒฝ ์ฌํญ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- ํ๋ จ์ด ์ฌ๋ฐ๋ฅธ ๊ทธ๋ํฝ ์ฝ์ด์์ ์ํ๋๋ค๋ ๊ฒ์ ์ดํดํ๋ ค๋ฉด ๋ก์ปฌ Horovod ์์(8ํ)์ ๋ฐ๋ผ ์ปจํ ์คํธ๋ฅผ ์ค์ ํ์ญ์์ค.
- ๋ชจ๋ ์์ ์๊ฐ ๋์ผํ ์ด๊ธฐ ๋งค๊ฐ๋ณ์๋ก ์์ํ๋๋ก ํ๋ ค๋ฉด ํ ์์ ์์ ์ด๊ธฐ ๋งค๊ฐ๋ณ์๋ฅผ ๋ชจ๋ ์์ ์์๊ฒ ์ ๋ฌํฉ๋๋ค(18ํ).
- ํธ๋ก๋ณด๋ ๋ง๋ค๊ธฐ ๋ถ์ฐ ์ต์ ํ ํ๋ก๊ทธ๋จ (๋ผ์ธ 25) ๋ถ์ฐ ๋ฐฉ์์ผ๋ก ๋งค๊ฐ๋ณ์๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค.
์ ์ฒด ์คํฌ๋ฆฝํธ๋ฅผ ์ป์ผ๋ ค๋ฉด Horovod-MXNet ์์ ๋ฅผ ์ฐธ์กฐํ์ธ์.
1 import mxnet as mx
2 import horovod.mxnet as hvd
3
4 # Horovod: initialize Horovod
5 hvd.init()
6
7 # Horovod: pin a GPU to be used to local rank
8 context = mx.gpu(hvd.local_rank())
9
10 # Build model
11 model = ...
12
13 # Initialize parameters
14 model.initialize(initializer, ctx=context)
15 params = model.collect_params()
16
17 # Horovod: broadcast parameters
18 hvd.broadcast_parameters(params, root_rank=0)
19
20 # Create optimizer
21 optimizer_params = ...
22 opt = mx.optimizer.create('sgd', **optimizer_params)
23
24 # Horovod: wrap optimizer with DistributedOptimizer
25 opt = hvd.DistributedOptimizer(opt)
26
27 # Create trainer and loss function
28 trainer = mx.gluon.Trainer(params, opt, kvstore=None)
29 loss_fn = ...
30
31 # Train model
32 for epoch in range(num_epoch):
33 ...
3 ๋จ๊ณ
MPI ์ง์นจ์ ์ฌ์ฉํ์ฌ ๋ถ์ฐ ๊ต์ก์ ์์ํ๋ ค๋ฉด ์์ ์ ์ค ํ ๋ช ์ ๋ก๊ทธ์ธํ์ธ์. ์ด ์์์ ๋ถ์ฐ ๊ต์ก์ ๊ฐ๊ฐ 4๊ฐ์ GPU๊ฐ ์๋ 16๊ฐ์ ์ธ์คํด์ค์์ ์คํ๋๋ฉฐ ํด๋ฌ์คํฐ์๋ ์ด XNUMX๊ฐ์ GPU๊ฐ ์์ต๋๋ค. SGD(Stochastic Gradient Descent) ์ต์ ํ ํ๋ก๊ทธ๋จ์ ๋ค์ ํ์ดํผํ๋ผ๋ฏธํฐ์ ํจ๊ป ์ฌ์ฉ๋ฉ๋๋ค.
- ๋ฏธ๋ ๋ฐฐ์น ํฌ๊ธฐ: 256
- ํ์ต๋ฅ : 0.1
- ์ด๋๋: 0.9
- ์ค๋ ๊ฐ์: 0.0001
64๊ฐ์ GPU์์ 0,1๊ฐ์ GPU๋ก ํ์ฅํ๋ฉด์ GPU ์์ ๋ฐ๋ผ ํ๋ จ ์๋๋ฅผ ์ ํ์ ์ผ๋ก ํ์ฅํ์ผ๋ฉฐ(1 GPU์ ๊ฒฝ์ฐ 6,4์์ 64 GPU์ ๊ฒฝ์ฐ 256), GPU๋น ์ด๋ฏธ์ง ์๋ 256๊ฐ(๋ฐฐ์น์์)๋ก ์ ์งํ์ต๋๋ค. 1๊ฐ GPU์ ๊ฒฝ์ฐ 16๊ฐ ์ด๋ฏธ์ง๋ถํฐ 384๊ฐ GPU์ ๊ฒฝ์ฐ 64๊ฐ ์ด๋ฏธ์ง๊น์ง). GPU ์๊ฐ ์ฆ๊ฐํจ์ ๋ฐ๋ผ ๋ฌด๊ฒ ๊ฐ์ ๋ฐ ์ด๋๋ ๋งค๊ฐ๋ณ์๊ฐ ๋ณ๊ฒฝ๋์์ต๋๋ค. NVIDIA Tesla GPU์์ ์ง์ํ๋ float16 ๊ณ์ฐ ์๋๋ฅผ ๋์ด๊ธฐ ์ํด ์๋ฐฉํฅ ์ ๋ฌ์๋ float32 ๋ฐ์ดํฐ ์ ํ์, ๊ฒฝ์ฌ์๋ float16 ๋ฐ์ดํฐ ์ ํ์ ์ฌ์ฉํ๋ ํผํฉ ์ ๋ฐ๋ ๊ต์ก์ ์ฌ์ฉํ์ต๋๋ค.
$ mpirun -np 16
-H server1:4,server2:4,server3:4,server4:4
-bind-to none -map-by slot
-mca pml ob1 -mca btl ^openib
python mxnet_imagenet_resnet50.py
๊ฒฐ๋ก
์ด ๊ธฐ์ฌ์์๋ Apache MXNet ๋ฐ Horovod๋ฅผ ์ฌ์ฉํ์ฌ ๋ถ์ฐ ๋ชจ๋ธ ๊ต์ก์ ๋ํ ํ์ฅ ๊ฐ๋ฅํ ์ ๊ทผ ๋ฐฉ์์ ์ดํด๋ณด์์ต๋๋ค. ResNet50-v1 ๋ชจ๋ธ์ด ํ๋ จ๋ ImageNet ๋ฐ์ดํฐ์ธํธ์ ๋ํ ๋งค๊ฐ๋ณ์ ์๋ฒ ์ ๊ทผ ๋ฐฉ์๊ณผ ๋น๊ตํ์ฌ ํ์ฅ ํจ์จ์ฑ๊ณผ ๋น์ฉ ํจ์จ์ฑ์ ์ ์ฆํ์ต๋๋ค. ๋ํ Horovod๋ฅผ ์ฌ์ฉํ์ฌ ๋ค์ค ์ธ์คํด์ค ๊ต์ก์ ์คํํ๊ธฐ ์ํด ๊ธฐ์กด ์คํฌ๋ฆฝํธ๋ฅผ ์์ ํ๋ ๋ฐ ์ฌ์ฉํ ์ ์๋ ๋จ๊ณ๋ ํฌํจ๋์ด ์์ต๋๋ค.
MXNet๊ณผ ๋ฅ๋ฌ๋์ ์ด์ ๋ง ์์ํ์
จ๋ค๋ฉด ์ค์น ํ์ด์ง๋ก ์ด๋ํ์ธ์
์ด๋ฏธ MXNet์ ์ฌ์ฉํด๋ณธ ์ ์ด ์๊ณ Horovod๋ก ๋ถ์ฐ ํ์ต์ ์๋ํ๊ณ ์ถ๋ค๋ฉด ๋ค์์ ์ดํด๋ณด์ธ์.
*๋น์ฉ์ ๋ค์์ ๊ธฐ์ค์ผ๋ก ๊ณ์ฐ๋ฉ๋๋ค.
๊ณผ์ ์ ๋ํด ์์ธํ ์์๋ณด๊ธฐ
์ถ์ฒ : habr.com