jittor.mpi

This is the API documentation for Jittor's MPI module. You can access the module with from jittor import mpi.

How to adapt single-GPU code for multi-GPU training

When launched with mpirun, the following modules detect the MPI environment and automatically switch to their multi-GPU versions (a launch sketch follows this list):

  • jittor.optimizer: synchronizes gradients automatically

  • jittor.nn.BatchNorm*: synchronized batch norm

  • jittor.dataset: automatic data parallelism
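For example, an unchanged single-GPU training script can be launched on 4 GPUs with a single command (the script name train.py is hypothetical):

mpirun -np 4 python train.py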

In most cases, single-GPU training code can be run distributed on multiple GPUs with mpirun without modification. However, the following two situations still require code changes:

  1. Writing to disk (saving models, logging curves)

  2. Aggregating global statistics (e.g. the overall accuracy on the validation set)

Writing to disk

For the first case, suppose your original code looks like this:

for i, (images, labels) in enumerate(dataset):
    output = model(images)
    loss = nn.cross_entropy_loss(output, labels)
    acc1 = accuracy(output, labels)
    SGD.step(loss)
    loss_data = loss.data
    writer.add_scalar("Train/loss", loss_data, i)

The adapted code looks like this:

for i, (images, labels) in enumerate(dataset):
    output = model(images)
    loss = nn.cross_entropy_loss(output, labels)
    acc1 = accuracy(output, labels)
    SGD.step(loss)
    loss_data = loss.data
    if jt.rank == 0:
        writer.add_scalar("Train/loss", loss_data, i)

Here we use jt.rank so that only the first process writes the loss. This code also works on a single GPU, because jt.rank is 0 there. Note that code inside the if jt.rank == 0 block must not call any Jittor API: that would very likely make the API calls inconsistent across the processes and cause a deadlock!
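To make the warning concrete, here is a minimal sketch (reusing writer, loss and i from the loop above) of the unsafe pattern next to the safe one:

# Unsafe: loss.data triggers Jittor's computation on rank 0 only; if the
# graph contains collective MPI ops (synced gradients, synced batch norm),
# the other ranks never join the collective call and the program can deadlock.
if jt.rank == 0:
    writer.add_scalar("Train/loss", loss.data, i)

# Safe: evaluate on every rank, write on rank 0 only.
loss_data = loss.data
if jt.rank == 0:
    writer.add_scalar("Train/loss", loss_data, i)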

Aggregating global statistics

There are two ways to aggregate global statistics. The first is to use the provided MPI ops, as in the validation code below:

def val(epoch):
    global min_error
    model.eval()
    correct_nums = 0
    for i, (images, labels) in enumerate(valdataset):
        output = model(images)
        correct_nums += top1error(output, labels)
        correct_nums.sync()
    top1_error = (valdataset.total_len - correct_nums.data[0]) / valdataset.total_len
    if top1_error < min_error:
        min_error = top1_error
        print("[*] Best model is updated ...")
        model.save('model_best.pkl')

The adapted version:

def val(epoch):
    global min_error
    model.eval()
    correct_nums = 0
    for i, (images, labels) in enumerate(valdataset):
        output = model(images)
        correct_nums += top1error(output, labels)
        correct_nums.sync()
    if jt.in_mpi:
        # Sum the per-process counts so every process sees the global total.
        correct_nums = correct_nums.mpi_all_reduce()
    top1_error = (valdataset.total_len - correct_nums.data[0]) / valdataset.total_len
    if jt.rank == 0 and top1_error < min_error:
        min_error = top1_error
        print("[*] Best model is updated ...")
        model.save('model_best.pkl')

Notice that we first call mpi_all_reduce to aggregate the correct counts from all GPUs (mpi_all_reduce sums the results of all MPI processes), and then save the model only when jt.rank == 0.
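As a standalone illustration (a sketch; it must be launched with mpirun), each process contributes one value and every process receives the global result:

import jittor as jt

x = jt.array([jt.rank + 1.0])  # rank 0 holds [1.0], rank 1 holds [2.0], ...
s = x.mpi_all_reduce("add")    # with n processes, every rank now holds [n*(n+1)/2]
m = x.mpi_all_reduce("mean")   # every rank holds the average instead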

The second approach is @jt.single_process_scope(): the decorated function is executed by a single process, so it needs no multi-GPU handling at all.

@jt.single_process_scope()
def val(epoch):
    ......
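A fuller sketch of this approach (model and valdataset as in the examples above): inside the decorated function, plain single-process validation code works unchanged:

@jt.single_process_scope()
def val(epoch):
    # This body runs in a single process; no mpi ops or jt.rank checks needed.
    model.eval()
    correct = 0
    for images, labels in valdataset:
        pred = model(images).numpy().argmax(axis=1)
        correct += (pred == labels.numpy()).sum()
    print("epoch %d: top-1 acc = %.4f" % (epoch, correct / valdataset.total_len))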

Below is the Jittor MPI API reference.

jittor_mpi_core.broadcast()

Document: Use jt.Module.mpi_param_broadcast(root=0) to broadcast all module parameters of this module from the [root] MPI node to all MPI nodes.

This operation has no gradient, and the input parameter type is a numpy array.

Declaration: void _mpi_broadcast(ArrayArgs&& args, int root)
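A short sketch of the module-level helper mentioned above (model stands for any jt.Module you have constructed):

# Make every rank adopt rank 0's initial parameters before training starts.
model.mpi_param_broadcast(root=0)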

jittor_mpi_core.get_state()

Document: Get the MPI state (enabled or disabled).

Declaration: inline int _mpi_get_state()

jittor_mpi_core.local_rank()

Document: Return the local ID of this MPI node.

Declaration: int _mpi_local_rank()

jittor_mpi_core.set_state()

Document: Set the MPI state: enable or disable. If disabled, all MPI operators have no effect.

Declaration: inline void _mpi_set_state(bool enable)

jittor_mpi_core.world_rank()

Document: Return the global ID of this MPI node.

Declaration: int _mpi_world_rank()

jittor_mpi_core.world_size()

Document: Return the number of MPI nodes.

Declaration: int _mpi_world_size()
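These queries are exposed in Python through the mpi module (from jittor import mpi, also available as jt.mpi); a small sketch, where the printed values depend on how mpirun launched the job:

import jittor as jt

if jt.in_mpi:
    print(jt.mpi.world_size())  # total number of MPI processes
    print(jt.mpi.world_rank())  # global ID, in 0 .. world_size()-1
    print(jt.mpi.local_rank())  # ID among the processes on this machine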

jittor_mpi_core.ops.mpi_all_reduce()

Document: The MPI all-reduce operator uses [op] to reduce variable [x] across all MPI nodes and broadcasts the result to all MPI nodes.

Args:

  • x: the variable to be all-reduced.

  • op: 'sum' or 'add' sums all [x]; 'mean' averages over all [x]. Default: 'add'.

Declaration: VarHolder* mpi_all_reduce(VarHolder* x, NanoString op=ns_add)

jittor_mpi_core.ops.mpi_broadcast()

Document: The MPI broadcast operator broadcasts variable [x] from the [root] MPI node to all MPI nodes.

Args:

  • x: the variable to be broadcast.

  • root: ID of the MPI node to broadcast from. Default: 0.

Declaration: VarHolder* mpi_broadcast(VarHolder* x, int root=0)
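A sketch of the variable-level broadcast (run under mpirun): after the call, every rank holds rank 0's value:

import jittor as jt

x = jt.array([float(jt.rank)])  # rank 0: [0.0], rank 1: [1.0], ...
y = x.mpi_broadcast(root=0)     # every rank now holds [0.0]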

jittor_mpi_core.ops.mpi_reduce()

Document: The MPI reduce operator uses [op] to reduce variable [x] across all MPI nodes and sends the result to the [root] MPI node.

Args:

  • x: the variable to be reduced.

  • op: 'sum' or 'add' sums all [x]; 'mean' averages over all [x]. Default: 'add'.

  • root: ID of the MPI node that receives the result. Default: 0.

Declaration: VarHolder* mpi_reduce(VarHolder* x, NanoString op=ns_add, int root=0)
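A sketch of the rooted reduction (run under mpirun): the sum lands on the root rank only:

import jittor as jt

x = jt.array([1.0])
y = x.mpi_reduce(op="add", root=0)  # rank 0 holds [world_size]; other ranks
                                    # should not rely on the value of y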

jittor_mpi_core.ops.mpi_test()

Declaration: VarHolder* mpi_test(string cmd)