# jittor.loss3d.chamfer source code

# Author: Zheng-Ning Liu 
# 
# This file implements chamfer loss on both CPU and GPU.
# The implementation does not use an extra NxM matrix to store distances, and
# thus supports large point clouds.

import jittor as jt
import jittor.nn as nn

# CPU source for the nearest-neighbour index operator built with `jt.code`
# in `chamfer_loss` (called there with inputs [pc1, pc2], i.e. in0 = pc1 and
# in1 = pc2, both already in [batch, points, xyz] layout).
#
# For every batch index `bs` and every point `i` of in0, the kernel scans all
# points `j` of in1 and stores in `out(bs, i)` the index of the point with the
# smallest squared Euclidean distance. It compares coordinates 0, 1 and 2 only.
# The strict `<` keeps the lowest index on ties. The work is O(N * M) per batch
# element, but no N x M distance buffer is ever allocated.
#
# Preconditions (not checked here): in1 has at least one point, because
# `in1(bs, 0, *)` is read unconditionally, and both inputs have at least three
# coordinates per point.
# `@in0`, `@in1`, `@out` and `inX_shapeY` are presumably placeholders that
# `jt.code` expands into indexed accesses / shape values — confirm against the
# jittor `jt.code` documentation.
cpu_src = '''
    for (int bs = 0; bs < in0_shape0; ++bs)
        for (int i = 0; i < in0_shape1; ++i) {
            float min_dis = (@in0(bs, i, 0) - @in1(bs, 0, 0)) * (@in0(bs, i, 0) - @in1(bs, 0, 0)) +
                            (@in0(bs, i, 1) - @in1(bs, 0, 1)) * (@in0(bs, i, 1) - @in1(bs, 0, 1)) +
                            (@in0(bs, i, 2) - @in1(bs, 0, 2)) * (@in0(bs, i, 2) - @in1(bs, 0, 2));
            @out(bs, i) = 0;
            for (int j = 1; j < in1_shape1; ++j) {
                float dis = (@in0(bs, i, 0) - @in1(bs, j, 0)) * (@in0(bs, i, 0) - @in1(bs, j, 0)) +
                            (@in0(bs, i, 1) - @in1(bs, j, 1)) * (@in0(bs, i, 1) - @in1(bs, j, 1)) +
                            (@in0(bs, i, 2) - @in1(bs, j, 2)) * (@in0(bs, i, 2) - @in1(bs, j, 2));
                if (dis < min_dis) {
                    min_dis = dis;
                    @out(bs, i) = j;
                }
            }
        }
'''

# CUDA source for the same nearest-neighbour index operator as the CPU
# version (used by `chamfer_loss` through `jt.code`).
#
# Launch layout: one CUDA block per batch element
# (`<<<in0_shape0, 512>>>`, `bs = blockIdx.x`) with 512 threads. Each thread
# handles points i = threadIdx.x, threadIdx.x + blockDim.x, ... of in0. For
# each such point it scans every point j of in1 and writes the index of the one
# with the smallest squared distance (coordinates 0..2, lowest index on ties)
# into `out(bs, i)`.
#
# As on CPU, in1 must contain at least one point, because `in1(bs, 0, *)` is
# read unconditionally.
# `@ARGS_DEF`, `@ARGS` and `@PRECALC` are presumably `jt.code` template
# placeholders for the kernel parameter list, the call arguments and
# shape/stride setup — confirm against the jittor `jt.code` documentation.
cuda_src = '''
    __global__ void chamfer_loss_min_idx_kernel(@ARGS_DEF) {
        @PRECALC
        int bs = blockIdx.x;
        int n = in0_shape1;
        int m = in1_shape1;

        for (int i = threadIdx.x; i < n; i += blockDim.x) {
            float min_dis = (@in0(bs, i, 0) - @in1(bs, 0, 0)) * (@in0(bs, i, 0) - @in1(bs, 0, 0)) +
                            (@in0(bs, i, 1) - @in1(bs, 0, 1)) * (@in0(bs, i, 1) - @in1(bs, 0, 1)) +
                            (@in0(bs, i, 2) - @in1(bs, 0, 2)) * (@in0(bs, i, 2) - @in1(bs, 0, 2));
            @out(bs, i) = 0;
            for (int j = 1; j < m; ++j) {
                float dis = (@in0(bs, i, 0) - @in1(bs, j, 0)) * (@in0(bs, i, 0) - @in1(bs, j, 0)) +
                            (@in0(bs, i, 1) - @in1(bs, j, 1)) * (@in0(bs, i, 1) - @in1(bs, j, 1)) +
                            (@in0(bs, i, 2) - @in1(bs, j, 2)) * (@in0(bs, i, 2) - @in1(bs, j, 2));
                if (dis < min_dis) {
                    min_dis = dis;
                    @out(bs, i) = j;
                }
            }
        }
    }

    chamfer_loss_min_idx_kernel<<<in0_shape0, 512>>>(@ARGS);
'''


def chamfer_loss(pc1, pc2, reduction='mean', dims='BNC', bidirectional=False):
    '''
    return the chamfer loss from pc1 to pc2.

    For each point of ``pc1`` the nearest point of ``pc2`` is found (by
    squared Euclidean distance over the three coordinates), and the
    Euclidean distance to that point is taken. The resulting per-point
    distances, of shape ``[batch, N]``, are then reduced according to
    ``reduction``.

    :param pc1: input point cloud
    :type pc1: jittor array
    :param pc2: input point cloud
    :type pc2: jittor array
    :param reduction: reduction method in batches, can be 'mean', 'sum', or None. Default: 'mean'.
    :type reduction: str, optional
    :param dims: a string that represents each dimension, can be
            '[BNC]' ([batch, number of points, xyz]), or
            '[BCN]' ([batch, xyz, number of points]). Default: 'BNC'.
    :type dims: str, optional
    :param bidirectional: if True, return the sum of the loss from pc1 to
            pc2 and the loss from pc2 to pc1. Default: False.
    :type bidirectional: bool, optional
    :return: the per-point distances ``[batch, N]`` when ``reduction`` is
            None, otherwise their sum ('sum') or mean ('mean').
    :raises ValueError: if ``reduction`` is not 'mean', 'sum' or None, if
            either point cloud does not have exactly 3 coordinates per
            point, or if ``pc2`` contains no points.
    :raises AssertionError: if ``dims`` is not 'BNC' or 'BCN', or if the
            batch sizes of ``pc1`` and ``pc2`` differ.

    Example:

    >>> import jittor as jt
    >>> from jittor.loss3d import chamfer_loss
    >>> jt.flags.use_cuda = True
    >>> pc1 = jt.rand([10, 100, 3], dtype=jt.float32)
    >>> pc2 = jt.rand([10, 100, 3], dtype=jt.float32)
    >>> cf = chamfer_loss(pc1, pc2, dims='BNC', bidirectional=True)
    >>> print('chamfer loss =', cf.item())
    '''
    # Validate options before launching any kernel. Previously an unknown
    # reduction fell through every branch and silently returned None.
    if reduction not in (None, 'sum', 'mean'):
        raise ValueError(
            "reduction must be 'mean', 'sum' or None, got {!r}".format(reduction))
    assert dims in ['BNC', 'BCN']

    if bidirectional:
        return (chamfer_loss(pc1, pc2, reduction, dims)
                + chamfer_loss(pc2, pc1, reduction, dims))

    if dims == 'BCN':
        pc1, pc2 = pc1.permute(0, 2, 1), pc2.permute(0, 2, 1)

    batch_size_1, N, C1 = pc1.shape
    batch_size_2, M, C2 = pc2.shape
    assert batch_size_1 == batch_size_2
    batch_size = batch_size_1

    # The index kernels and the gather below hard-code coordinates 0..2, and
    # the kernels read pc2's first point unconditionally. Any other shape
    # would read out of bounds instead of failing cleanly.
    if C1 != 3 or C2 != 3:
        raise ValueError(
            'chamfer_loss expects 3 coordinates per point, got {} and {} '
            '(dims={!r})'.format(C1, C2, dims))
    if M == 0:
        raise ValueError('chamfer_loss requires pc2 to contain at least one point')

    # idx[b, i] = index of the point of pc2[b] closest to pc1[b, i].
    idx = jt.code([batch_size, N], 'int32', [pc1, pc2],
                  cpu_src=cpu_src,
                  cuda_src=cuda_src)

    # nearest_pts[b, i, c] = pc2[b, idx[b, i], c]
    nearest_pts = pc2.reindex([batch_size, N, 3], [
        'i0',
        '@e0(i0, i1)',
        'i2'
    ], extras=[idx])

    # NOTE(review): the derivative of sqrt is unbounded at 0, so a pc1 point
    # that coincides exactly with its nearest pc2 point may yield a
    # non-finite gradient here. Confirm whether an epsilon is wanted.
    chamfer_distance = (((pc1 - nearest_pts) ** 2).sum(dim=-1)).sqrt()

    if reduction is None:
        return chamfer_distance
    if reduction == 'sum':
        return jt.sum(chamfer_distance)
    return jt.mean(chamfer_distance)
class ChamferLoss(nn.Module):
    '''
    A loss layer that computes the chamfer loss from pc1 to pc2.

    The options given to the constructor are stored on the module and
    forwarded unchanged to :func:`chamfer_loss` on every call.

    :param reduction: reduction method in batches, can be 'mean', 'sum', or None. Default: 'mean'.
    :type reduction: str, optional
    :param dims: a string that represents each dimension, can be
            '[BNC]' ([batch, number of points, xyz]), or
            '[BCN]' ([batch, xyz, number of points]). Default: 'BNC'.
    :type dims: str, optional
    :param bidirectional: if True, the loss from pc2 to pc1 is added to the
            loss from pc1 to pc2. Default: False.
    :type bidirectional: bool, optional

    Calling the module with ``(pc1, pc2)`` (two jittor point clouds) returns
    ``chamfer_loss(pc1, pc2, ...)`` computed with the stored options.

    Example:

    >>> import jittor as jt
    >>> from jittor.loss3d import ChamferLoss
    >>> jt.flags.use_cuda = True
    >>> pc1 = jt.rand([10, 100, 3], dtype=jt.float32)
    >>> pc2 = jt.rand([10, 100, 3], dtype=jt.float32)
    >>> CF = ChamferLoss(dims='BNC', bidirectional=True)
    >>> cf = CF(pc1, pc2)
    >>> print('chamfer loss =', cf.item())
    '''

    def __init__(self, reduction='mean', dims='BNC', bidirectional=False):
        ''' Store the options later forwarded to @chamfer_loss. '''
        super().__init__()
        self.reduction, self.dims, self.bidirectional = reduction, dims, bidirectional

    def execute(self, pc1, pc2):
        ''' Return the chamfer loss from pc1 to pc2 using the stored options. '''
        return chamfer_loss(
            pc1,
            pc2,
            reduction=self.reduction,
            dims=self.dims,
            bidirectional=self.bidirectional,
        )