Jittor性能测试与对比方法

下面代码以AlexNet为例,用于演示 Jittor 性能测试的正确方法:

import time
import jittor as jt
from jittor.models import resnet50
jt.flags.use_cuda = jt.has_cuda

warmup = 10
rerun = 100
batch_size = 8
data = jt.random((batch_size, 3, 224, 224))
model = resnet50()
model.eval()

# 此段代码对jittor进行热身,确保时间测试准确
jt.sync_all(True)
for i in range(warmup):
    pred = model(data)
    # sync是把计算图发送到计算设备上
    pred.sync()
# sync_all(true)是把计算图发射到计算设备上,并且同步。
# 只有运行了jt.sync_all(True)才会真正地运行,时间才是有效的,因此执行forward前后都要执行这句话
jt.sync_all(True)

# 开始测试运行时间
start = time.time()
for i in range(rerun):
    pred = model(data)
    pred.sync()
jt.sync_all(True)
end = time.time()

print("Jittor FPS:", (rerun*batch_size)/(end-start))

在这段代码中,我们定义了几个参数batch_size, warmup, rerun, batch_size代表批大小,warmup是用于热身的循环次数,而rerun是用于测速的循环次数,最终输出FPS,对Jittor进行正确测速的关键是 热身部分和同步部分,热身部分确保测试时间稳定,没有包含编译用的时间,而同步部分确保计算完成,因为jittor是一个异步框架,只有同步操作能保证计算完成。

以上代码的运行结果如下(RTX Titan,batch 8):

Compiling Operators(8/8) used: 7.35s eta:    0s
Compiling Operators(13/13) used: 8.36s eta:    0s
Jittor FPS: 908.9853866375396

我们还可以使用类似的代码测试 PyTorch的性能:

import time
import torch
from torchvision.models import resnet50

warmup = 10
rerun = 100
batch_size = 8
data = torch.randn((batch_size, 3, 224, 224)).cuda()
model = resnet50()
model.cuda()
model.eval()

# 此段代码对pytorch进行热身,确保时间测试准确
torch.cuda.synchronize()
for i in range(warmup):
    pred = model(data)
# synchronize用于确保PyTorch计算完成
torch.cuda.synchronize()

# 开始测试运行时间
start = time.time()
for i in range(rerun):
    pred = model(data)
torch.cuda.synchronize()
end = time.time()

print("PyTorch FPS:", (rerun*batch_size)/(end-start))

以上代码的运行结果如下(RTX Titan,batch 8):

PyTorch FPS: 807.4806873965665

我们还可以对这两段代码合并,并对比结果的一致性:

import time
import jittor as jt
from jittor.models import resnet50
jt.flags.use_cuda = jt.has_cuda

warmup = 100
rerun = 1000
batch_size = 8
data = jt.random((batch_size, 3, 224, 224))
model = resnet50()
model.eval()

# 此段代码对jittor进行热身,确保时间测试准确
jt.sync_all(True)
for i in range(warmup):
    pred = model(data)
    # sync是把计算图发送到计算设备上
    pred.sync()
# sync_all(true)是把计算图发射到计算设备上,并且同步。
# 只有运行了jt.sync_all(True)才会真正地运行,时间才是有效的,因此执行forward前后都要执行这句话
jt.sync_all(True)

# 开始测试运行时间
start = time.time()
for i in range(rerun):
    pred = model(data)
    pred.sync()
jt.sync_all(True)
end = time.time()

print("Jittor FPS:", (rerun*batch_size)/(end-start))
# 将 jittor 数据和参数导出为 numpy 和 torch 格式
jittor_data = pred.numpy()
jittor_param = model.state_dict(to="torch")

import numpy as np
import torch
from torchvision.models import resnet50
data = torch.Tensor(data.numpy()).cuda()
model = resnet50()
# 加载 jittor 参数
model.load_state_dict(jittor_param)
model.cuda()
model.eval()

# 此段代码对pytorch进行热身,确保时间测试准确
torch.cuda.synchronize()
for i in range(warmup):
    pred = model(data)
# synchronize用于确保PyTorch计算完成
torch.cuda.synchronize()

# 开始测试运行时间
start = time.time()
for i in range(rerun):
    pred = model(data)
torch.cuda.synchronize()
end = time.time()

print("PyTorch FPS:", (rerun*batch_size)/(end-start))
pytorch_data = pred.detach().cpu().numpy()
err = np.mean(np.abs(pytorch_data - jittor_data))
print("mean error:", err)

以上代码运行结果如下:

Jittor FPS: 908.9853866375396
PyTorch FPS: 807.4806873965665
mean error: 1e-5

误差输出为1e-5, 在可接受范围内。正确测速与对比的几大关键点为:

  1. 充分热身,除去框架的准备时间。

  2. 多次运行,确保测试时间稳定。

  3. 加上同步语句,确保测试时间准确。

  4. 保证显存充足,在显存不足时,jittor会调用统一内存来弥补,会产生性能损失,请密切关注nvidia-smi的输出结果。

  5. 保证对比模型的一致性,检查输出结果的一致。

如果您对测试结果有疑问,或者有优化需求,欢迎随时联系Jittor开发团队。