我发现为了从stride 1 生成(X - x + 1, Y - y + 1)
大小的补丁,image 要求我们将 strides 参数设置为or 。我不知道他们如何快速计算出这个数字。conv2d 的进步(x,y)
(X,Y)
img.strides * 2
img.strides + img.strides
但是我应该怎么做才能((X-x)/stride)+1, ((Y-y)/stride)+1
从相同大小的图像中获得相同大小的补丁stride
呢?
从这个 SO答案 稍作修改,通道和图像数量放在前面
def patchify(img, patch_shape):
a,b,X, Y = img.shape # a images and b channels
x, y = patch_shape
shape = (a, b, X - x + 1, Y - y + 1, x, y)
a_str, b_str, X_str, Y_str = img.strides
strides = (a_str, b_str, X_str, Y_str, X_str, Y_str)
return np.lib.stride_tricks.as_strided(img, shape=shape, strides=strides)
我可以看到它创建了一个大小为 (x,y) 且步幅为 1 的滑动窗口(向右移动 1 个像素并向下移动 1 个像素)。我无法将as_strided
使用的步幅参数与我们通常用于 conv2d 的步幅关联起来。
如何向上述计算as_strided
strides 参数的函数添加参数?
def patchify(img, patch_shape, stride): # stride=stepsize in conv2d eg: 1,2,3,...
a,b,X,Y = img.shape # a images and b channels
x, y = patch_shape
shape = (a,b,((X-x)/stride)+1, ((Y-y)/stride)+1, x, y)
strides = ??? # strides for as_strided
return np.lib.stride_tricks.as_strided(img, shape=shape, strides=strides)
img 是 4d(a, b, X, Y)
a
=图像数量,b
=频道数,(X,Y)
= 宽度和高度
注意:stride in conv2d
我的意思stepsize
是不幸的是,这也称为步幅。
注意 2:由于stepsize
通常在两个轴上都相同,因此在我提供的代码中,我只提供了一个参数,但是将它用于两个维度。
游乐场:strides
这里
有什么用。我让它在stepsize=1
这里运行。我注意到它可能无法从链接中使用,但在粘贴到新的playground时可以使用。
这应该清楚地了解我需要什么:
[[ 0.5488135 0.71518937 0.60276338 0.54488318]
[ 0.4236548 0.64589411 0.43758721 0.891773 ]
[ 0.96366276 0.38344152 0.79172504 0.52889492]
[ 0.56804456 0.92559664 0.07103606 0.0871293 ]]
# patch_size = 2x2
# stride = 1,1
[[[[ 0.5488135 0.71518937]
[ 0.4236548 0.64589411]]
[[ 0.71518937 0.60276338]
[ 0.64589411 0.43758721]]
[[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]]]
[[[ 0.4236548 0.64589411]
[ 0.96366276 0.38344152]]
[[ 0.64589411 0.43758721]
[ 0.38344152 0.79172504]]
[[ 0.43758721 0.891773 ]
[ 0.79172504 0.52889492]]]
[[[ 0.96366276 0.38344152]
[ 0.56804456 0.92559664]]
[[ 0.38344152 0.79172504]
[ 0.92559664 0.07103606]]
[[ 0.79172504 0.52889492]
[ 0.07103606 0.0871293 ]]]]
# stride = 2,2
[[[[[[ 0.5488135 0.71518937]
[ 0.4236548 0.64589411]]
[[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]]]
[[[ 0.96366276 0.38344152]
[ 0.56804456 0.92559664]]
[[ 0.79172504 0.52889492]
[ 0.07103606 0.0871293 ]]]]]]
# stride = 2,1
[[[[ 0.5488135 0.71518937]
[ 0.4236548 0.64589411]]
[[ 0.71518937 0.60276338]
[ 0.64589411 0.43758721]]
[[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]]]
[[[ 0.96366276 0.38344152]
[ 0.56804456 0.92559664]]
[[ 0.38344152 0.79172504]
[ 0.92559664 0.07103606]]
[[ 0.79172504 0.52889492]
[ 0.07103606 0.0871293 ]]]]