2020/03/17

元算子：通过元算子实现自己的卷积层

``````def matmul(a, b):
(n, m), k = a.shape, b.shape[-1]
return (a*b).sum(dim=1)
``````

``````import numpy as np
import os
def conv_naive(x, w):
    """Reference 'valid' 2-D convolution (stride 1, no padding) with plain loops.

    Args:
        x: input array of shape (N, H, W, C).
        w: kernel array of shape (Kh, Kw, C, Kc).

    Returns:
        Array of shape (N, H-Kh+1, W-Kw+1, Kc) created with ``np.zeros``'s
        default dtype, where
        ``y[n, h, w, k] = sum over (dh, dw, c) of x[n, h+dh, w+dw, c] * w[dh, dw, c, k]``.

    Raises:
        AssertionError: if the channel counts of ``x`` and ``w`` differ.
    """
    N, H, W, C = x.shape
    Kh, Kw, _C, Kc = w.shape
    assert C == _C, (x.shape, w.shape)
    out_h, out_w = H - Kh + 1, W - Kw + 1
    y = np.zeros([N, out_h, out_w, Kc])
    for i0 in range(N):
        for i1 in range(out_h):
            for i2 in range(out_w):
                for i3 in range(Kh):
                    for i4 in range(Kw):
                        for i5 in range(C):
                            for i6 in range(Kc):
                                # No bounds check is needed: i1 + i3 <= H - 1 and
                                # i2 + i4 <= W - 1 by construction of the output
                                # size, so every kernel tap reads a valid pixel.
                                y[i0, i1, i2, i6] += x[i0, i1 + i3, i2 + i4, i5] * w[i3, i4, i5, i6]
    return y
``````

``````# %matplotlib inline
import pylab as pl
img_path="/tmp/cat.jpg"
if not os.path.isfile(img_path):
!wget -O - 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/4f/Felis_silvestris_catus_lying_on_rice_straw.jpg/220px-Felis_silvestris_catus_lying_on_rice_straw.jpg' > \$img_path
pl.subplot(121)
pl.imshow(img)
kernel = np.array([
[-1, -1, -1],
[0, 0, 0],
[1, 1, 1],
])
pl.subplot(122)
x = img[np.newaxis,:,:,:1].astype("float32")
w = kernel[:,:,np.newaxis,np.newaxis].astype("float32")
y = conv_naive(x, w)
print (x.shape, y.shape) # shape exists confusion
pl.imshow(y[0,:,:,0])
``````

``````import jittor as jt

def conv(x, w):
    """'Valid' 2-D convolution composed from Jittor meta-operators.

    Mirrors ``conv_naive``: gather the input windows with ``reindex``, broadcast
    the kernel to the same index space, multiply element-wise, then reduce.

    Args:
        x: Jittor variable whose shape unpacks as (N, H, W, C).
        w: Jittor variable whose shape unpacks as (Kh, Kw, C, Kc).

    Returns:
        The product summed over axes 3, 4, 5 (Kh, Kw, C), leaving
        (N, H-Kh+1, W-Kw+1, Kc).

    Raises:
        AssertionError: if the channel counts of ``x`` and ``w`` differ.
    """
    N, H, W, C = x.shape
    Kh, Kw, _C, Kc = w.shape
    assert C == _C, (x.shape, w.shape)
    # Gather step: xx[i0, i1, i2, i3, i4, i5, i6] = x[i0, i1+i3, i2+i4, i5].
    xx = x.reindex([N, H-Kh+1, W-Kw+1, Kh, Kw, C, Kc], [
        'i0',     # Nid
        'i1+i3',  # Hid+Khid
        'i2+i4',  # Wid+KWid
        'i5',     # Cid
    ])
    # Expand the kernel to xx's 7-D shape so the two can be multiplied
    # element-wise; this definition was missing, leaving `ww` undefined.
    ww = w.broadcast_var(xx)
    yy = xx * ww
    y = yy.sum([3, 4, 5])  # reduce over Kh, Kw, C
    return y

# Switch the tuner off; with it disabled Jittor will not hand the convolution
# over to MKL, so the meta-operator version defined above is what runs.
jt.flags.enable_tuner = 0

# Wrap the NumPy input and kernel as Jittor variables.
jx, jw = (jt.array(arr) for arr in (x, w))
jy = conv(jx, jw).fetch_sync()
print(jx.shape, jy.shape)
# Show batch item 0, output channel 0.
pl.imshow(jy[0, :, :, 0])
``````

``````%time y = conv_naive(x, w)
# Timing comparison: the line above runs the pure-Python loop version, the line
# below runs the Jittor version (the fetch_sync() call is part of the measurement).
%time jy = conv(jx, jw).fetch_sync()
``````

``````help(jt.reindex)  # print the built-in documentation of the reindex operator used by conv()
``````

``````xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [
'i0', # Nid
'i1+i3', # Hid+Khid
'i2+i4', # Wid+KWid
'i5', # Cid
])
yy = xx*ww
y = yy.sum([3,4,5]) # Kh, Kw, c
``````

``````shape = [N,H+Kh-1,W+Kw-1,Kh,Kw,C,Kc]
# expansion of x.reindex
xx = np.zeros(shape, x.dtype)
for i0 in range(shape[0]):
for i1 in range(shape[1]):
for i2 in range(shape[2]):
for i3 in range(shape[3]):
for i4 in range(shape[4]):
for i5 in range(shape[5]):
for i6 in range(shape[6]):
if is_overflow(i0,i1,i2,i3,i4,i5,i6):
xx[i0,i1,...,in] = 0
else:
xx[i0,i1,i2,i3,i4,i5,i6] = x[i0,i1+i3,i2+i4,i5]
ww = np.zeros(shape, x.dtype)
for i0 in range(shape[0]):
for i1 in range(shape[1]):
for i2 in range(shape[2]):
for i3 in range(shape[3]):
for i4 in range(shape[4]):
for i5 in range(shape[5]):
for i6 in range(shape[6]):
ww[i0,i1,i2,i3,i4,i5,i6] = w[i3,i4,i5,i6]
# expansion of xx*ww
yy = np.zeros(shape, x.dtype)
for i0 in range(shape[0]):
for i1 in range(shape[1]):
for i2 in range(shape[2]):
for i3 in range(shape[3]):
for i4 in range(shape[4]):
for i5 in range(shape[5]):
for i6 in range(shape[6]):
yy[i0,i1,i2,i3,i4,i5,i6] = xx[i0,i1,i2,i3,i4,i5,i6] * ww[i0,i1,i2,i3,i4,i5,i6]
# expansion of yy.sum([3,4,5])
shape2 = [N,H-Kh+1,W-Kw+1,Kc]
y = np.zeros(shape2, x.dtype)
for i0 in range(shape[0]):
for i1 in range(shape[1]):
for i2 in range(shape[2]):
for i3 in range(shape[3]):
for i4 in range(shape[4]):
for i5 in range(shape[5]):
for i6 in range(shape[6]):
y[i0,i1,i2,i6] += yy[i0,i1,i2,i3,i4,i5,i6]
``````

``````shape2 = [N,H-Kh+1,W-Kw+1,Kc]
y = np.zeros(shape2, x.dtype)
for i0 in range(shape[0]):
for i1 in range(shape[1]):
for i2 in range(shape[2]):
for i3 in range(shape[3]):
for i4 in range(shape[4]):
for i5 in range(shape[5]):
for i6 in range(shape[6]):
if not is_overflow(i0,i1,i2,i3,i4,i5,i6):
y[i0,i1,i2,i6] += x[i0,i1+i3,i2+i4,i5] * w[i3,i4,i5,i6]
``````

Jittor 会尝试将融合后的算子优化得尽可能快。让我们尝试一项优化（将形状作为常量编译进内核），并查看生成的底层 C++ 内核代码。

``````jt.flags.compile_options={"compile_shapes":1}
with jt.profile_scope() as report:
jy = conv(jx, jw).fetch_sync()
jt.flags.compile_options={}

print(f"Time: {float(report[1][4])/1e6}ms")

with open(report[1][1], 'r') as f: