
I am working with 3D images of shape 172x220x156. To feed an image through the network and get its output, I need to extract patches of size 32x32x32 from the image and then stitch them back together to obtain the image again. Since my image dimensions are not multiples of the patch size, the patches have to overlap. I would like to know how to do this.

I am working in PyTorch, and there are the options unfold and fold, but I am not sure how they work.


3 Answers


You can use unfold (PyTorch docs):

import torch

# Treat the 172x220x156 volume as a (batch, depth, height, width) tensor.
batch_size, n_channels, n_rows, n_cols = 1, 172, 220, 156
x = torch.arange(batch_size*n_channels*n_rows*n_cols).view(batch_size, n_channels, n_rows, n_cols)

kernel_c, kernel_h, kernel_w = 32, 32, 32
step = 32

# Tensor.unfold(dimension, size, step)
windows_unpacked = x.unfold(1, kernel_c, step).unfold(2, kernel_h, step).unfold(3, kernel_w, step)
print(windows_unpacked.shape)
# result: torch.Size([1, 5, 6, 4, 32, 32, 32])

windows = windows_unpacked.permute(1, 2, 3, 0, 4, 5, 6).reshape(-1, kernel_c, kernel_h, kernel_w)
print(windows.shape)
# result: torch.Size([120, 32, 32, 32])
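
Note that with a step equal to the kernel size the windows do not overlap, and the remainder along each axis (e.g. 172 - 5*32 = 12 voxels in the first dimension) is simply dropped. For this non-overlapping case, reversing the permute/reshape recovers the covered part of the volume; a minimal sketch, continuing from the snippet above:

n_d, n_h, n_w = windows_unpacked.shape[1:4]   # 5, 6, 4 blocks per axis

# Put each block index right next to its within-block axis, then merge them back
# into full spatial dimensions.
x_rec = windows_unpacked.permute(0, 1, 4, 2, 5, 3, 6).reshape(
    batch_size, n_d * kernel_c, n_h * kernel_h, n_w * kernel_w)
print(x_rec.shape)
# result: torch.Size([1, 160, 192, 128])
print(torch.equal(x_rec, x[:, :n_d * kernel_c, :n_h * kernel_h, :n_w * kernel_w]))
# result: True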
Answered 2020-09-16T15:12:44.520

To extract (overlapping) patches and reconstruct the input shape, we can use torch.nn.functional.unfold and the inverse operation torch.nn.functional.fold. These methods only handle 4D tensors, i.e. 2D images, but you can apply them one dimension at a time.
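
As a minimal 2D sketch of how these two functions behave (assuming a toy 1x1x4x4 image and overlapping 2x2 patches): F.unfold extracts the patches as columns, F.fold sums the overlapping elements back, and dividing by a folded tensor of ones recovers the input:

import torch
import torch.nn.functional as F

img = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4)   # (B, C, H, W)

patches = F.unfold(img, kernel_size=2, stride=1)              # (1, C*2*2, 9): one column per patch
summed = F.fold(patches, output_size=(4, 4), kernel_size=2, stride=1)   # overlaps are summed
counts = F.fold(torch.ones_like(patches), output_size=(4, 4), kernel_size=2, stride=1)
print(torch.allclose(summed / counts, img))
# True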

A few notes:

  1. This approach needs PyTorch's fold/unfold methods; unfortunately, I have not found a similar method in the TF API.

  2. We can extract the patches in two ways whose outputs are identical. The methods are called extract_patches_Xd and extract_patches_Xds, where X is the number of dimensions (here 3). The latter uses torch.Tensor.unfold() and needs fewer lines of code. (The outputs are the same, except that dilation cannot be used.)

  3. The methods extract_patches_Xd and combine_patches_Xd are inverses of each other, and the combiner reverses the extractor's steps one by one.

  4. Each line of code is followed by a comment stating the dimensions, e.g. (B, C, D, H, W), using the following notation:

    1. B: batch size
    2. C: channels
    3. D: depth dimension
    4. H: height dimension
    5. W: width dimension
    6. x_dim_in: in the extraction method, this is the number of input pixels in dimension x; in the combining method, it is the number of sliding windows in dimension x.
    7. x_dim_out: in the extraction method, this is the number of sliding windows in dimension x; in the combining method, it is the number of output pixels in dimension x.
  5. I have a public notebook where you can try out the code.

  6. The get_dim_blocks() method is the function given on the PyTorch docs site for computing the output shape of a convolutional layer.

  7. Note that if you have overlapping patches and combine them, the overlapping elements are summed. If you want to recover the original input, there is a way (demonstrated at the end of the example below, where c is divided by ones):

    1. Create a tensor of the same shape as the patches with torch.ones_like(patches_tensor).
    2. Combine those ones-patches into a full image with the same output shape. (This produces a per-element counter of how often each element was covered.)
    3. Divide the combined image by the combined ones tensor; this undoes the summation of the overlapping elements.

(3D): We need fold and unfold twice each. For extraction we first unfold along the D dimension while keeping H and W untouched, by setting the kernel to 1, the padding to 0, the stride to 1 and the dilation to 1 in that direction; then we view the tensor and unfold the H and W dimensions. Folding happens in reverse, starting with H and W, then D.
import torch

def extract_patches_3ds(x, kernel_size, padding=0, stride=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)

    channels = x.shape[1]

    x = torch.nn.functional.pad(x, padding)
    # (B, C, D, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
    return x

def extract_patches_3d(x, kernel_size, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]

    d_dim_in = x.shape[2]
    h_dim_in = x.shape[3]
    w_dim_in = x.shape[4]
    d_dim_out = get_dim_blocks(d_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)
    
    # (B, C, D, H, W)
    x = x.view(-1, channels, d_dim_in, h_dim_in * w_dim_in)                                                     
    # (B, C, D, H * W)

    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))                   
    # (B, C * kernel_size[0], d_dim_out * H * W)

    x = x.view(-1, channels * kernel_size[0] * d_dim_out, h_dim_in, w_dim_in)                                   
    # (B, C * kernel_size[0] * d_dim_out, H, W)

    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))        
    # (B, C * kernel_size[0] * d_dim_out * kernel_size[1] * kernel_size[2], h_dim_out, w_dim_out)

    x = x.view(-1, channels, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)  
    # (B, C, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)  

    x = x.permute(0, 1, 3, 6, 7, 2, 4, 5)
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])

    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])

    return x



def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    d_dim_in = get_dim_blocks(d_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)

    x = x.view(-1, channels, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B, C, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])

    x = x.permute(0, 1, 5, 2, 6, 7, 3, 4)
    # (B, C, kernel_size[0], d_dim_in, kernel_size[1], kernel_size[2], h_dim_in, w_dim_in)

    x = x.contiguous().view(-1, channels * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)

    x = torch.nn.functional.fold(x, output_size=(h_dim_out, w_dim_out), kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_in, H, W)

    x = x.view(-1, channels * kernel_size[0], d_dim_in * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0], d_dim_in * H * W)

    x = torch.nn.functional.fold(x, output_size=(d_dim_out, h_dim_out * w_dim_out), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C, D, H * W)
    
    x = x.view(-1, channels, d_dim_out, h_dim_out, w_dim_out)
    # (B, C, D, H, W)

    return x

a = torch.arange(1, 129, dtype=torch.float).view(2,2,2,4,4)
print(a.shape)
print(a)
b = extract_patches_3d(a, 2, padding=1, stride=1)
bs = extract_patches_3ds(a, 2, padding=1, stride=1)
print(b.shape)
print(b)
c = combine_patches_3d(b, 2, (2,2,2,4,4), padding=1, stride=1)
print(c.shape)
print(c)
ones = torch.ones_like(b)
ones = combine_patches_3d(ones, 2, (2,2,2,4,4), padding=1, stride=1)
print(torch.all(a==c))
print(c.shape, ones.shape)
d = c / ones
print(d)
print(torch.all(a==d))

Output (3D)

torch.Size([2, 2, 2, 4, 4])
tensor([[[[[  1.,   2.,   3.,   4.],
           [  5.,   6.,   7.,   8.],
           [  9.,  10.,  11.,  12.],
           [ 13.,  14.,  15.,  16.]],

          [[ 17.,  18.,  19.,  20.],
           [ 21.,  22.,  23.,  24.],
           [ 25.,  26.,  27.,  28.],
           [ 29.,  30.,  31.,  32.]]],


         [[[ 33.,  34.,  35.,  36.],
           [ 37.,  38.,  39.,  40.],
           [ 41.,  42.,  43.,  44.],
           [ 45.,  46.,  47.,  48.]],

          [[ 49.,  50.,  51.,  52.],
           [ 53.,  54.,  55.,  56.],
           [ 57.,  58.,  59.,  60.],
           [ 61.,  62.,  63.,  64.]]]],



        [[[[ 65.,  66.,  67.,  68.],
           [ 69.,  70.,  71.,  72.],
           [ 73.,  74.,  75.,  76.],
           [ 77.,  78.,  79.,  80.]],

          [[ 81.,  82.,  83.,  84.],
           [ 85.,  86.,  87.,  88.],
           [ 89.,  90.,  91.,  92.],
           [ 93.,  94.,  95.,  96.]]],


         [[[ 97.,  98.,  99., 100.],
           [101., 102., 103., 104.],
           [105., 106., 107., 108.],
           [109., 110., 111., 112.]],

          [[113., 114., 115., 116.],
           [117., 118., 119., 120.],
           [121., 122., 123., 124.],
           [125., 126., 127., 128.]]]]])
torch.Size([150, 2, 2, 2, 2])
tensor([[[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   1.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  1.,   2.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  2.,   3.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  3.,   4.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  4.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   1.],
           [  0.,   5.]]]],



        ...,



        [[[[124.,   0.],
           [128.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0., 125.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[125., 126.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[126., 127.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[127., 128.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[128.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]]])
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[   8.,   16.,   24.,   32.],
           [  40.,   48.,   56.,   64.],
           [  72.,   80.,   88.,   96.],
           [ 104.,  112.,  120.,  128.]],

          [[ 136.,  144.,  152.,  160.],
           [ 168.,  176.,  184.,  192.],
           [ 200.,  208.,  216.,  224.],
           [ 232.,  240.,  248.,  256.]]],


         [[[ 264.,  272.,  280.,  288.],
           [ 296.,  304.,  312.,  320.],
           [ 328.,  336.,  344.,  352.],
           [ 360.,  368.,  376.,  384.]],

          [[ 392.,  400.,  408.,  416.],
           [ 424.,  432.,  440.,  448.],
           [ 456.,  464.,  472.,  480.],
           [ 488.,  496.,  504.,  512.]]]],



        [[[[ 520.,  528.,  536.,  544.],
           [ 552.,  560.,  568.,  576.],
           [ 584.,  592.,  600.,  608.],
           [ 616.,  624.,  632.,  640.]],

          [[ 648.,  656.,  664.,  672.],
           [ 680.,  688.,  696.,  704.],
           [ 712.,  720.,  728.,  736.],
           [ 744.,  752.,  760.,  768.]]],


         [[[ 776.,  784.,  792.,  800.],
           [ 808.,  816.,  824.,  832.],
           [ 840.,  848.,  856.,  864.],
           [ 872.,  880.,  888.,  896.]],

          [[ 904.,  912.,  920.,  928.],
           [ 936.,  944.,  952.,  960.],
           [ 968.,  976.,  984.,  992.],
           [1000., 1008., 1016., 1024.]]]]])
tensor(False)
torch.Size([2, 2, 2, 4, 4]) torch.Size([2, 2, 2, 4, 4])
tensor([[[[[  1.,   2.,   3.,   4.],
           [  5.,   6.,   7.,   8.],
           [  9.,  10.,  11.,  12.],
           [ 13.,  14.,  15.,  16.]],

          [[ 17.,  18.,  19.,  20.],
           [ 21.,  22.,  23.,  24.],
           [ 25.,  26.,  27.,  28.],
           [ 29.,  30.,  31.,  32.]]],


         [[[ 33.,  34.,  35.,  36.],
           [ 37.,  38.,  39.,  40.],
           [ 41.,  42.,  43.,  44.],
           [ 45.,  46.,  47.,  48.]],

          [[ 49.,  50.,  51.,  52.],
           [ 53.,  54.,  55.,  56.],
           [ 57.,  58.,  59.,  60.],
           [ 61.,  62.,  63.,  64.]]]],



        [[[[ 65.,  66.,  67.,  68.],
           [ 69.,  70.,  71.,  72.],
           [ 73.,  74.,  75.,  76.],
           [ 77.,  78.,  79.,  80.]],

          [[ 81.,  82.,  83.,  84.],
           [ 85.,  86.,  87.,  88.],
           [ 89.,  90.,  91.,  92.],
           [ 93.,  94.,  95.,  96.]]],


         [[[ 97.,  98.,  99., 100.],
           [101., 102., 103., 104.],
           [105., 106., 107., 108.],
           [109., 110., 111., 112.]],

          [[113., 114., 115., 116.],
           [117., 118., 119., 120.],
           [121., 122., 123., 124.],
           [125., 126., 127., 128.]]]]])
tensor(True)
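
Applied to the volume from the question, a sketch (assuming a batch of one single-channel volume, and per-dimension strides of 28, 4 and 31 picked so that (dim - 32) is divisible by the stride, which makes the overlapping 32x32x32 windows tile the 172x220x156 volume exactly):

volume = torch.randn(1, 1, 172, 220, 156)   # (B, C, D, H, W)
patches = extract_patches_3d(volume, kernel_size=32, stride=(28, 4, 31))
print(patches.shape)
# torch.Size([1440, 1, 32, 32, 32])  (6 * 48 * 5 windows)

# ... run the patches through the network here ...

summed = combine_patches_3d(patches, kernel_size=32, output_shape=(1, 1, 172, 220, 156), stride=(28, 4, 31))
counts = combine_patches_3d(torch.ones_like(patches), kernel_size=32, output_shape=(1, 1, 172, 220, 156), stride=(28, 4, 31))
print(torch.allclose(summed / counts, volume))
# True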

Answered 2021-07-05T12:18:37.880

Is all of your data exactly 172x220x156? If so, it seems like you could just use a for loop and index into the tensor to get 32x32x32 blocks, right? (Possibly hard-coding a few things.)

However, I can't answer this completely because it is not clear how you want to combine the results. To be clear, is this your goal?

  1. Get a 32x32x32 patch from the image
  2. Do some arbitrary processing on it
  3. Save that patch into a result at the correct index
  4. Repeat

If so, how do you intend to combine the overlapping patches? Sum them? Average them?

But, as for the indexing:

out_tensor = torch.zeros_like(tensor)
for i_idx in [0, 32, 64, 96, 128, 140]:
    for j_idx in [0, 32, 64, 96, 128, 160, 188]:
        for k_idx in [0, 32, 64, 96, 124]:
            # slice out a 32x32x32 block (the last start index in each list is chosen
            # so that the final block ends exactly at the edge of the volume)
            patch = tensor[i_idx:i_idx+32, j_idx:j_idx+32, k_idx:k_idx+32]
            output = your_model(patch)
            out_tensor[i_idx:i_idx+32, j_idx:j_idx+32, k_idx:k_idx+32] = output

This is not optimized at all, but I suspect the bulk of the computation will be the actual neural network, and there is no way around that, so optimizing this loop is probably pointless.
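
If you do want to average the overlapping regions with this loop-based approach, one option is to accumulate into the output and keep a per-voxel count (a sketch, reusing the same hard-coded start indices and the placeholder your_model from above, and assuming tensor holds floating-point data):

out_tensor = torch.zeros_like(tensor)
counts = torch.zeros_like(tensor)
for i_idx in [0, 32, 64, 96, 128, 140]:
    for j_idx in [0, 32, 64, 96, 128, 160, 188]:
        for k_idx in [0, 32, 64, 96, 124]:
            patch = tensor[i_idx:i_idx+32, j_idx:j_idx+32, k_idx:k_idx+32]
            out_tensor[i_idx:i_idx+32, j_idx:j_idx+32, k_idx:k_idx+32] += your_model(patch)
            counts[i_idx:i_idx+32, j_idx:j_idx+32, k_idx:k_idx+32] += 1
out_tensor /= counts   # every voxel is covered at least once, so the division is safe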

Answered 2019-07-22T05:48:02.650