我有一批大小的分割图像
seg
--> [batch, channels, imsize, imgsize]
-->[16, 6, 50, 50]
这个张量中的每个标量都指定了一个分割类。我们有2000
总的分割类。
现在的目标是转换
[16, 6, 50, 50]
-->[16, 2000, 50, 50]
每个类都以一种热门方式编码。
我如何使用 pytorch api 做到这一点?我只能想到可笑的低效循环构造。
例子
在这里,我们将只有 2 个初始通道(而不是 6 个)、4 个标签(而不是 2000)、大小批量 1(而不是 16)和 4x4 图像而不是 50x50。
0, 0, 1, 1
0, 0, 0, 1
1, 1, 1, 1
1, 1, 1, 1
3, 3, 2, 2
3, 3, 2, 2
3, 3, 2, 2
3, 3, 2, 2
现在这变成了4通道输出
1, 1, 0, 0
1, 1, 1, 0
0, 0, 0, 0
0, 0, 0, 0
0, 0, 1, 1
0, 0, 0, 1
1, 1, 1, 1
1, 1, 1, 1
1, 1, 0, 0
1, 1, 0, 0
1, 1, 0, 0
1, 1, 0, 0
0, 0, 1, 1
0, 0, 1, 1
0, 0, 1, 1
0, 0, 1, 1
关键观察是特定标签仅出现在单个输入通道上。