I am working with 3D images of shape 172x220x156. To feed an image through my network I need to extract patches of size 32x32x32 from it, and afterwards put them back together to get the image again. Since my image dimensions are not multiples of the patch size, I have to use overlapping patches. I would like to know how to do this.
I am working in PyTorch, and there are the options unfold and fold, but I am not sure how they work.
You can use unfold (PyTorch docs):
import torch

# treat the 172x220x156 volume as a (1, 172, 220, 156) tensor
batch_size, n_channels, n_rows, n_cols = 1, 172, 220, 156
x = torch.arange(batch_size*n_channels*n_rows*n_cols).view(batch_size, n_channels, n_rows, n_cols)
kernel_c, kernel_h, kernel_w = 32, 32, 32
step = 32
# Tensor.unfold(dimension, size, step)
windows_unpacked = x.unfold(1, kernel_c, step).unfold(2, kernel_h, step).unfold(3, kernel_w, step)
print(windows_unpacked.shape)
# result: torch.Size([1, 5, 6, 4, 32, 32, 32])
windows = windows_unpacked.permute(1, 2, 3, 0, 4, 5, 6).reshape(-1, kernel_c, kernel_h, kernel_w)
print(windows.shape)
# result: torch.Size([120, 32, 32, 32])
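Note that with step = 32 the windows do not overlap and the leftover voxels at each border are dropped (172 = 5*32 + 12, and similarly for the other dimensions). If you need full coverage, one option is to choose the start positions yourself so that the last window ends exactly at the border, which creates one overlapping window per dimension. A minimal sketch (the helper start_positions is my own, not part of unfold):

def start_positions(dim_size, patch_size, step):
    # regular grid of starts, plus one extra start so the last patch
    # ends exactly at dim_size (that last patch overlaps its neighbour)
    starts = list(range(0, dim_size - patch_size + 1, step))
    if starts[-1] != dim_size - patch_size:
        starts.append(dim_size - patch_size)
    return starts

print(start_positions(172, 32, 32))  # [0, 32, 64, 96, 128, 140]
print(start_positions(220, 32, 32))  # [0, 32, 64, 96, 128, 160, 188]
print(start_positions(156, 32, 32))  # [0, 32, 64, 96, 124]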
To extract (overlapping) patches and to reconstruct the input shape, we can use torch.nn.functional.unfold and the inverse operation torch.nn.functional.fold. These methods only work on 4D tensors (i.e. 2D images), but you can use them to process one dimension at a time.
A few notes:
This approach requires PyTorch's fold/unfold methods; unfortunately I have not found similar methods in the TF API.
We can extract the patches in two ways, and their outputs are identical. The methods are called extract_patches_Xd and extract_patches_Xds, where X is the number of dimensions (3 in this case). The latter uses torch.Tensor.unfold() and has fewer lines of code. (The output is the same, except that it cannot use dilation.)
The methods extract_patches_Xd and combine_patches_Xd are inverse operations, and the combiner reverses the extractor's steps one by one.
The lines of code are followed by comments stating the dimensions, such as (B, C, D, H, W). The following are used:
B: batch size
C: channels
D: depth dimension
H: height dimension
W: width dimension
x_dim_in: in the extraction method, this is the number of input pixels in dimension x. In the combine method, this is the number of sliding windows in dimension x.
x_dim_out: in the extraction method, this is the number of sliding windows in dimension x. In the combine method, this is the number of output pixels in dimension x.
I have a public notebook where you can try out the code.
The get_dim_blocks() method is the function given on the PyTorch docs site for computing the output shape of a convolutional layer.
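As a worked check (my own, not from the original text): for the asker's 172x220x156 volume with kernel 32 and stride 32 (no padding, dilation 1), this formula gives the 5 x 6 x 4 grid of windows seen in the other answer:

def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
    return (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1

print(get_dim_blocks(172, 32, dim_stride=32))  # 5
print(get_dim_blocks(220, 32, dim_stride=32))  # 6
print(get_dim_blocks(156, 32, dim_stride=32))  # 4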
Note that if you have overlapping patches and combine them, the overlapping elements will be summed. If you want to recover the initial input, there is a way: apply the same combine method to torch.ones_like(patches_tensor) and divide the combined patches by the combined ones (this is done at the end of the example below).
fold and unfold: we first apply unfold to the D dimension and keep H and W untouched by setting the corresponding kernel to 1, padding to 0, stride to 1 and dilation to 1. After that we view the tensor and unfold the H and W dimensions. The folding happens in reverse, starting with H and W, then D.
import torch

def extract_patches_3ds(x, kernel_size, padding=0, stride=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)

    channels = x.shape[1]

    x = torch.nn.functional.pad(x, padding)
    # (B, C, D, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
    return x
def extract_patches_3d(x, kernel_size, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    d_dim_in = x.shape[2]
    h_dim_in = x.shape[3]
    w_dim_in = x.shape[4]
    d_dim_out = get_dim_blocks(d_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)

    # (B, C, D, H, W)
    x = x.view(-1, channels, d_dim_in, h_dim_in * w_dim_in)
    # (B, C, D, H * W)

    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C * kernel_size[0], d_dim_out * H * W)

    x = x.view(-1, channels * kernel_size[0] * d_dim_out, h_dim_in, w_dim_in)
    # (B, C * kernel_size[0] * d_dim_out, H, W)

    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_out * kernel_size[1] * kernel_size[2], h_dim_out * w_dim_out)

    x = x.view(-1, channels, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)
    # (B, C, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)

    x = x.permute(0, 1, 3, 6, 7, 2, 4, 5)
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])

    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])

    return x
def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    d_dim_in = get_dim_blocks(d_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)

    x = x.view(-1, channels, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B, C, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])

    x = x.permute(0, 1, 5, 2, 6, 7, 3, 4)
    # (B, C, kernel_size[0], d_dim_in, kernel_size[1], kernel_size[2], h_dim_in, w_dim_in)

    x = x.contiguous().view(-1, channels * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)

    x = torch.nn.functional.fold(x, output_size=(h_dim_out, w_dim_out), kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_in, H, W)

    x = x.view(-1, channels * kernel_size[0], d_dim_in * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0], d_dim_in * H * W)

    x = torch.nn.functional.fold(x, output_size=(d_dim_out, h_dim_out * w_dim_out), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C, D, H * W)

    x = x.view(-1, channels, d_dim_out, h_dim_out, w_dim_out)
    # (B, C, D, H, W)

    return x
a = torch.arange(1, 129, dtype=torch.float).view(2, 2, 2, 4, 4)
print(a.shape)
print(a)
b = extract_patches_3d(a, 2, padding=1, stride=1)
bs = extract_patches_3ds(a, 2, padding=1, stride=1)
print(b.shape)
print(b)
c = combine_patches_3d(b, kernel_size=2, output_shape=(2, 2, 2, 4, 4), padding=1, stride=1)
print(c.shape)
print(c)
ones = torch.ones_like(b)
ones = combine_patches_3d(ones, kernel_size=2, output_shape=(2, 2, 2, 4, 4), padding=1, stride=1)
print(torch.all(a == c))
print(c.shape, ones.shape)
d = c / ones
print(d)
print(torch.all(a == d))
Output (3D):
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[ 13., 14., 15., 16.]],
[[ 17., 18., 19., 20.],
[ 21., 22., 23., 24.],
[ 25., 26., 27., 28.],
[ 29., 30., 31., 32.]]],
[[[ 33., 34., 35., 36.],
[ 37., 38., 39., 40.],
[ 41., 42., 43., 44.],
[ 45., 46., 47., 48.]],
[[ 49., 50., 51., 52.],
[ 53., 54., 55., 56.],
[ 57., 58., 59., 60.],
[ 61., 62., 63., 64.]]]],
[[[[ 65., 66., 67., 68.],
[ 69., 70., 71., 72.],
[ 73., 74., 75., 76.],
[ 77., 78., 79., 80.]],
[[ 81., 82., 83., 84.],
[ 85., 86., 87., 88.],
[ 89., 90., 91., 92.],
[ 93., 94., 95., 96.]]],
[[[ 97., 98., 99., 100.],
[101., 102., 103., 104.],
[105., 106., 107., 108.],
[109., 110., 111., 112.]],
[[113., 114., 115., 116.],
[117., 118., 119., 120.],
[121., 122., 123., 124.],
[125., 126., 127., 128.]]]]])
torch.Size([150, 2, 2, 2, 2])
tensor([[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 1.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 1., 2.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 2., 3.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 3., 4.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 4., 0.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 1.],
[ 0., 5.]]]],
...,
[[[[124., 0.],
[128., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 0., 125.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[125., 126.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[126., 127.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[127., 128.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[128., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]]])
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[ 8., 16., 24., 32.],
[ 40., 48., 56., 64.],
[ 72., 80., 88., 96.],
[ 104., 112., 120., 128.]],
[[ 136., 144., 152., 160.],
[ 168., 176., 184., 192.],
[ 200., 208., 216., 224.],
[ 232., 240., 248., 256.]]],
[[[ 264., 272., 280., 288.],
[ 296., 304., 312., 320.],
[ 328., 336., 344., 352.],
[ 360., 368., 376., 384.]],
[[ 392., 400., 408., 416.],
[ 424., 432., 440., 448.],
[ 456., 464., 472., 480.],
[ 488., 496., 504., 512.]]]],
[[[[ 520., 528., 536., 544.],
[ 552., 560., 568., 576.],
[ 584., 592., 600., 608.],
[ 616., 624., 632., 640.]],
[[ 648., 656., 664., 672.],
[ 680., 688., 696., 704.],
[ 712., 720., 728., 736.],
[ 744., 752., 760., 768.]]],
[[[ 776., 784., 792., 800.],
[ 808., 816., 824., 832.],
[ 840., 848., 856., 864.],
[ 872., 880., 888., 896.]],
[[ 904., 912., 920., 928.],
[ 936., 944., 952., 960.],
[ 968., 976., 984., 992.],
[1000., 1008., 1016., 1024.]]]]])
tensor(False)
torch.Size([2, 2, 2, 4, 4]) torch.Size([2, 2, 2, 4, 4])
tensor([[[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[ 13., 14., 15., 16.]],
[[ 17., 18., 19., 20.],
[ 21., 22., 23., 24.],
[ 25., 26., 27., 28.],
[ 29., 30., 31., 32.]]],
[[[ 33., 34., 35., 36.],
[ 37., 38., 39., 40.],
[ 41., 42., 43., 44.],
[ 45., 46., 47., 48.]],
[[ 49., 50., 51., 52.],
[ 53., 54., 55., 56.],
[ 57., 58., 59., 60.],
[ 61., 62., 63., 64.]]]],
[[[[ 65., 66., 67., 68.],
[ 69., 70., 71., 72.],
[ 73., 74., 75., 76.],
[ 77., 78., 79., 80.]],
[[ 81., 82., 83., 84.],
[ 85., 86., 87., 88.],
[ 89., 90., 91., 92.],
[ 93., 94., 95., 96.]]],
[[[ 97., 98., 99., 100.],
[101., 102., 103., 104.],
[105., 106., 107., 108.],
[109., 110., 111., 112.]],
[[113., 114., 115., 116.],
[117., 118., 119., 120.],
[121., 122., 123., 124.],
[125., 126., 127., 128.]]]]])
tensor(True)
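Back to the shape in the question: a hedged sketch of how these helpers could be applied to a 172x220x156 volume (the variable volume and the choices of stride 16 and padding 2 are my own; with padding 2, (dim + 4 - 32) is divisible by 16 in all three dimensions, so every voxel falls into at least one patch and the division by the folded ones recovers the input):

import torch

volume = torch.randn(1, 1, 172, 220, 156)                                   # (B, C, D, H, W)
patches = extract_patches_3d(volume, kernel_size=32, padding=2, stride=16)  # overlapping 32x32x32 patches
print(patches.shape)                                                         # expected: torch.Size([1170, 1, 32, 32, 32])
summed = combine_patches_3d(patches, kernel_size=32, output_shape=volume.shape, padding=2, stride=16)
counts = combine_patches_3d(torch.ones_like(patches), kernel_size=32, output_shape=volume.shape, padding=2, stride=16)
print(torch.allclose(summed / counts, volume))                               # expected: True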
Is all of your data exactly 172x220x156? If so, it seems like you could just use a for loop and index into the tensor to get the 32x32x32 chunks, right? (Possibly hard-coding a few things.)
However, I can't fully answer this, because it is not clear how you want to combine the results. To be clear, is this your goal? 1) Take a 32x32x32 patch from the image. 2) Do some arbitrary processing on it. 3) Save that patch to result at the correct index. 4) Repeat.
If so, how do you intend to combine the overlapping patches? Sum them? Average them?
But - the indexing:

out_tensor = torch.zeros_like(tensor)
for i_idx in [0, 32, 64, 96, 128, 140]:
    for j_idx in [0, 32, 64, 96, 128, 160, 188]:
        for k_idx in [0, 32, 64, 96, 124]:
            patch = tensor[i_idx:i_idx + 32, j_idx:j_idx + 32, k_idx:k_idx + 32]
            output = your_model(patch)
            out_tensor[i_idx:i_idx + 32, j_idx:j_idx + 32, k_idx:k_idx + 32] = output
This is not optimized at all, but I imagine the bulk of the computation will be the actual neural network, and there is no way around that, so optimizing the looping is probably pointless.
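If averaging the overlaps is what you want, one option (a sketch of my own, assuming the processed patches should simply be averaged where they overlap) is to also accumulate a count tensor and divide at the end:

out_tensor = torch.zeros_like(tensor)
counts = torch.zeros_like(tensor)
for i_idx in [0, 32, 64, 96, 128, 140]:
    for j_idx in [0, 32, 64, 96, 128, 160, 188]:
        for k_idx in [0, 32, 64, 96, 124]:
            patch = tensor[i_idx:i_idx + 32, j_idx:j_idx + 32, k_idx:k_idx + 32]
            out_tensor[i_idx:i_idx + 32, j_idx:j_idx + 32, k_idx:k_idx + 32] += your_model(patch)
            counts[i_idx:i_idx + 32, j_idx:j_idx + 32, k_idx:k_idx + 32] += 1
out_tensor = out_tensor / counts  # average where patches overlapped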