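MXNet supports distributed data-parallel training, either through Horovod (which averages gradients with allreduce) or through MXNet's built-in parameter server (the distributed KVStore).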
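Both approaches rest on the same principle. Every worker computes gradients on its own shard of the data, and the per-worker gradients are then averaged. With a mean loss and equal-sized shards, the averaged gradient is identical to the full-batch gradient. The toy below checks this for a 1-D linear model. It is a pure-Python sketch, not MXNet code. `grad_mse` and `shard` are hypothetical helpers, and the strided split is chosen only for illustration.

```python
# Toy illustration of data parallelism (pure Python, no MXNet or Horovod).
# grad_mse and shard are hypothetical helpers written for this sketch.


def grad_mse(w, xs, ys):
    """d/dw of 0.5 * mean((w*x - y)^2) for a 1-D linear model y = w*x."""
    return sum((w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def shard(seq, num_parts, part_index):
    """Strided split: worker `part_index` takes every num_parts-th sample."""
    return seq[part_index::num_parts]


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]  # data generated with true weight w = 2
w, num_workers = 0.0, 4

# Single process: gradient over the whole batch.
full_grad = grad_mse(w, xs, ys)

# Data parallel: each worker computes a gradient on its own equal-sized shard,
# then the per-worker gradients are averaged (what allreduce or a PS provides).
local_grads = [
    grad_mse(w, shard(xs, num_workers, r), shard(ys, num_workers, r))
    for r in range(num_workers)
]
avg_grad = sum(local_grads) / num_workers

print(full_grad, avg_grad)  # → -51.0 -51.0
```

If the shards had unequal sizes, a plain average of the local gradients would weight the samples unevenly. In that case the two results would no longer match exactly.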
Distributed training with Horovod takes the following steps:
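```shell
# HOROVOD_WITH_MXNET=1 requires the build to include MXNet support
HOROVOD_WITH_MXNET=1 pip install horovod
```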
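In the training script, initialize Horovod, shard the data, broadcast the initial weights, and let Horovod's trainer handle gradient averaging:

```python
import mxnet as mx
import horovod.mxnet as hvd

hvd.init()
ctx = mx.gpu(hvd.local_rank())  # one process per GPU

# Each worker reads only its own 1/hvd.size() share of the records
train_data = mx.io.ImageRecordIter(..., num_parts=hvd.size(), part_index=hvd.rank())

net = mx.gluon.nn.Sequential()
net.add(mx.gluon.nn.Dense(128))
net.add(mx.gluon.nn.Activation('relu'))
net.add(mx.gluon.nn.Dense(10))
net.initialize(ctx=ctx)

# Scale the learning rate by the number of workers (common Horovod practice)
opt = mx.optimizer.SGD(learning_rate=0.1 * hvd.size())
params = net.collect_params()
hvd.broadcast_parameters(params, root_rank=0)  # all workers start from rank 0's weights
trainer = hvd.DistributedTrainer(params, opt)   # averages gradients with allreduce

loss_fn = mx.gluon.loss.SoftmaxCrossEntropyLoss()
for batch in train_data:
    data = batch.data[0].as_in_context(ctx)
    label = batch.label[0].as_in_context(ctx)
    with mx.autograd.record():
        loss = loss_fn(net(data), label)
    loss.backward()
    trainer.step(data.shape[0])
```

Launch one process per GPU with `horovodrun`. For example, `horovodrun -np 4 -H localhost:4 python train.py` runs four processes on one machine, where `train.py` is the script above.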
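Horovod's allreduce is often explained as a ring all-reduce. The backend (NCCL, MPI or Gloo) picks the actual algorithm. The pure-Python simulation below sketches the ring idea and is not Horovod code. `ring_allreduce` is a hypothetical helper, and each inner list stands in for one worker's gradient buffer. After a reduce-scatter phase and an all-gather phase, every worker holds the element-wise sum. Dividing that sum by the worker count gives the averaged gradient.

```python
def ring_allreduce(buffers):
    """Simulate a ring all-reduce (sum) over equal-length lists, one per worker.

    Each buffer is split into n chunks (n = number of workers). After n-1
    reduce-scatter steps, every worker owns one fully summed chunk. After n-1
    all-gather steps, every worker holds the full sum.
    """
    n = len(buffers)
    length = len(buffers[0])
    # Chunk boundaries; integer division spreads any remainder across chunks.
    bounds = [(i * length // n, (i + 1) * length // n) for i in range(n)]
    data = [list(b) for b in buffers]  # copy so the inputs are not mutated

    # Reduce-scatter: in step s, worker r sends chunk (r - s) % n to worker
    # r + 1, which adds it into its own copy of that chunk.
    for s in range(n - 1):
        sends = []
        for r in range(n):
            c = (r - s) % n
            lo, hi = bounds[c]
            sends.append((c, data[r][lo:hi]))  # slice = snapshot before receiving
        for r, (c, payload) in enumerate(sends):
            dst = (r + 1) % n
            lo, _ = bounds[c]
            for i, v in enumerate(payload):
                data[dst][lo + i] += v

    # Worker r now owns the fully reduced chunk (r + 1) % n.
    # All-gather: in step s, worker r forwards chunk (r + 1 - s) % n.
    for s in range(n - 1):
        sends = []
        for r in range(n):
            c = (r + 1 - s) % n
            lo, hi = bounds[c]
            sends.append((c, data[r][lo:hi]))
        for r, (c, payload) in enumerate(sends):
            dst = (r + 1) % n
            lo, hi = bounds[c]
            data[dst][lo:hi] = payload
    return data


# Three hypothetical workers, each holding a local gradient of length 6.
grads = [
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    [10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
    [100.0, 200.0, 300.0, 400.0, 500.0, 600.0],
]
result = ring_allreduce(grads)
averaged = [v / len(grads) for v in result[0]]
print(result[0])  # → [111.0, 222.0, 333.0, 444.0, 555.0, 666.0]
print(averaged)   # → [37.0, 74.0, 111.0, 148.0, 185.0, 222.0]
```

The ring layout matters because each worker only ever talks to its neighbour. Per-worker traffic is about twice the buffer size regardless of how many workers join.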
Distributed training with a parameter server takes the following steps:
```shell
pip install mxnet
```
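In the training script, create a distributed KVStore and pass it to the Gluon `Trainer`:

```python
import mxnet as mx
from mxnet import kv

# Worker/server counts come from the launcher, not from this script
ps = kv.create('dist_sync')  # synchronous mode; 'dist_async' is the asynchronous variant

# Each worker reads only its own shard of the records
train_data = mx.io.ImageRecordIter(..., num_parts=ps.num_workers, part_index=ps.rank)

net = mx.gluon.nn.Sequential()
net.add(mx.gluon.nn.Dense(128))
net.add(mx.gluon.nn.Activation('relu'))
net.add(mx.gluon.nn.Dense(10))
net.initialize()

# The Trainer pushes gradients to, and pulls weights from, the servers via `ps`
trainer = mx.gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1}, kvstore=ps)

loss_fn = mx.gluon.loss.SoftmaxCrossEntropyLoss()
for batch in train_data:
    data, label = batch.data[0], batch.label[0]
    with mx.autograd.record():
        loss = loss_fn(net(data), label)
    loss.backward()
    trainer.step(data.shape[0])
```

Start the job with MXNet's launcher. For example, `python tools/launch.py -n 2 --launcher local python train.py` starts two workers on the local machine.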
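To illustrate what the synchronous (`dist_sync`) mode does, here is an in-process toy. The class and its `init`/`push`/`pull` methods are hypothetical and only mirror the idea of the KVStore operations. They are not MXNet's API. Workers push gradients for a key, and the server applies an SGD step only after every worker has pushed. Workers then pull the updated weights. For readability the toy averages the pushed gradients. In real MXNet, the scaling of the aggregated gradient depends on the optimizer's `rescale_grad` and on the batch size passed to `Trainer.step`, and this sketch does not model that.

```python
class ToyParameterServer:
    """Hypothetical stand-in for a synchronous parameter server (not MXNet code)."""

    def __init__(self, num_workers, lr):
        self.num_workers = num_workers
        self.lr = lr
        self.store = {}    # key -> current weights (list of floats)
        self.pending = {}  # key -> gradients pushed in the current round

    def init(self, key, value):
        self.store[key] = list(value)
        self.pending[key] = []

    def push(self, key, grad):
        self.pending[key].append(list(grad))
        # Synchronous mode: update only once every worker has pushed this round.
        if len(self.pending[key]) == self.num_workers:
            avg = [sum(col) / self.num_workers for col in zip(*self.pending[key])]
            self.store[key] = [w - self.lr * g for w, g in zip(self.store[key], avg)]
            self.pending[key] = []

    def pull(self, key):
        return list(self.store[key])


ps = ToyParameterServer(num_workers=2, lr=0.1)
ps.init("w", [1.0, 1.0])

ps.push("w", [1.0, 3.0])  # worker 0 pushes its gradient
print(ps.pull("w"))       # → [1.0, 1.0]  (still waiting for worker 1)

ps.push("w", [3.0, 5.0])  # worker 1 pushes; the server averages and updates
print(ps.pull("w"))       # approximately [0.8, 0.6]
```

An asynchronous server (`dist_async`) would instead apply each push as soon as it arrives. That removes the wait but lets workers compute on slightly stale weights.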