我正在学习 TensorFlow 2,并尝试利用池化索引(argmax)实现 Max Unpooling,以便构建 SegNet。
运行时我遇到了下面的问题。我先定义了函数 MaxUnpooling2D,
然后在模型中调用它。我猜问题出在 updates 和 mask 的形状是 (None, H, W, ch),也就是批次维度未知。
def MaxUnpooling2D(updates, mask, size=2):
    """Scatter pooled values back to the positions recorded in an argmax mask.

    Args:
        updates: 4-D tensor of pooled values, indexed below as
            (batch, height, width, channels).
        mask: integer tensor shaped like ``updates``.  Each entry is decoded
            as a flattened per-sample position
            ``(y * out_width + x) * channels + c``.
            NOTE(review): this is the layout produced by
            ``tf.nn.max_pool_with_argmax`` with its default
            ``include_batch_in_index=False`` -- confirm against the caller.
        size: integer factor by which height and width are enlarged.
            Defaults to 2, the value that used to be hard-coded.

    Returns:
        The result of ``tf.scatter_nd`` with shape
        (batch, height * size, width * size, channels).
    """
    mask = tf.cast(mask, tf.int32)
    dynamic_shape = tf.shape(updates, out_type=tf.int32)

    # Prefer statically known dimensions (plain Python ints) and fall back to
    # the dynamic value only where the dimension is unknown (typically the
    # batch axis).  This lets the scattered result keep a static height,
    # width and channel count, which layers applied afterwards (e.g. Conv2D)
    # need in order to build their weights.
    if updates.shape.rank is None:
        static_dims = [None] * 4
    else:
        static_dims = updates.shape.as_list()
    batch, height, width, channels = [
        dynamic_shape[axis] if dim is None else dim
        for axis, dim in enumerate(static_dims)
    ]

    out_height = height * size
    out_width = width * size
    output_shape = [batch, out_height, out_width, channels]

    # Build one (batch, y, x, channel) index per element of `updates`.
    ones = tf.ones_like(mask, dtype=tf.int32)
    batch_range = tf.reshape(
        tf.range(batch, dtype=tf.int32), [-1, 1, 1, 1])
    b = ones * batch_range
    y = mask // (out_width * channels)
    x = (mask // channels) % out_width
    f = ones * tf.range(channels, dtype=tf.int32)

    # Stacking on the last axis and flattening yields the same [N, 4] index
    # matrix as stacking on axis 0, reshaping to [4, N] and transposing.
    indices = tf.reshape(tf.stack([b, y, x, f], axis=-1), [-1, 4])
    values = tf.reshape(updates, [-1])
    return tf.scatter_nd(indices, values, output_shape)
def segnet_conv(
        inputs,
        kernel_size=3,
        kernel_initializer='glorot_uniform',
        batch_norm=False,
        **kwargs):
    """Build the visible encoder stage: two 64-filter convs, then max-pool.

    Each convolution is optionally followed by batch normalisation and is
    always followed by a LeakyReLU.  The result is pooled with
    ``tf.nn.max_pool_with_argmax`` so the argmax indices are available for
    unpooling later.

    Args:
        inputs: tensor fed to the first convolution.
        kernel_size: kernel size used by both convolutions.
        kernel_initializer: initializer used by both convolutions.
        batch_norm: when true, insert ``BatchNormalization`` after each conv.
        **kwargs: accepted and ignored.

    Returns:
        ``(pool1, mask1)``: the pooled tensor and its argmax indices.
        NOTE(review): ``segnet`` unpacks six values from this function; the
        remaining encoder stages are not part of this block and must be
        added (and returned) before that call can work.
    """
    features = inputs
    for index in (1, 2):
        features = Conv2D(
            filters=64,
            kernel_size=kernel_size,
            padding='same',
            activation=None,
            kernel_initializer=kernel_initializer,
            name='conv_{}'.format(index),
        )(features)
        if batch_norm:
            features = BatchNormalization(
                name='bn_{}'.format(index))(features)
        features = LeakyReLU(
            alpha=0.3, name='activation_{}'.format(index))(features)

    pool1, mask1 = tf.nn.max_pool_with_argmax(
        input=features,
        ksize=2,
        strides=2,
        padding='SAME',
    )
    # Previously the pooled tensor and indices were computed and then dropped,
    # so the function returned None.
    return pool1, mask1
def segnet_deconv(
        pool1,
        mask1,
        kernel_size=3,
        kernel_initializer='glorot_uniform',
        batch_norm=False,
        **kwargs
):
    """Unpool a pooled tensor with its argmax mask and apply a 512-filter conv.

    Args:
        pool1: pooled tensor to be unpooled.
        mask1: argmax indices belonging to ``pool1``.
        kernel_size: kernel size of the convolution.
        kernel_initializer: initializer of the convolution.
        batch_norm: accepted but not used by the code in this block.
        **kwargs: accepted and ignored.

    Returns:
        The output of the ``upconv_13`` convolution.

    NOTE(review): ``segnet`` calls this function with one pooled tensor and
    five masks, which this signature cannot accept; reconcile the two once
    the remaining decoder stages are written.
    """
    # The body used to read `pool5` / `mask5`, which are not defined in this
    # scope; use the tensors that were actually passed in.
    dec = MaxUnpooling2D(pool1, mask1)
    dec = Conv2D(
        filters=512,
        kernel_size=kernel_size,
        padding='same',
        activation=None,
        kernel_initializer=kernel_initializer,
        name='upconv_13',
    )(dec)
    return dec
def classifier(
        dec,
        ch_out=2,
        kernel_size=3,
        final_activation=None,
        batch_norm=False,
        **kwargs
):
    """Apply the 64-filter ReLU convolution ``dec_out1`` to the decoder output.

    Args:
        dec: tensor produced by the decoder.
        ch_out: accepted but not used by the code in this block.
        kernel_size: kernel size of the convolution.
        final_activation: accepted but not used by the code in this block.
        batch_norm: accepted but not used by the code in this block.
        **kwargs: accepted and ignored.

    Returns:
        The output of the ``dec_out1`` convolution.
        NOTE(review): presumably a final ``ch_out``-channel layer using
        ``final_activation`` is meant to follow; it is not present here.
    """
    dec = Conv2D(
        filters=64,
        kernel_size=kernel_size,
        activation='relu',
        padding='same',
        name='dec_out1',
    )(dec)
    # Without this return the caller received None instead of a tensor.
    return dec
# NOTE: this function is deliberately NOT decorated with @tf.function.  It
# wires Keras layers onto a symbolic `Input` so the result can be handed to
# `Model(inputs, outputs)`; it is a graph builder, not a per-batch
# computation to be traced.
def segnet(
        inputs,
        ch_out=2,
        kernel_size=3,
        kernel_initializer='glorot_uniform',
        final_activation=None,
        batch_norm=False,
        **kwargs
):
    """Chain encoder, decoder and classifier into one SegNet output tensor.

    Args:
        inputs: input tensor of the network.
        ch_out: forwarded to ``classifier``.
        kernel_size: forwarded to all three stages.
        kernel_initializer: forwarded to the encoder and the decoder.
        final_activation: forwarded to ``classifier``.
        batch_norm: forwarded to all three stages.
        **kwargs: accepted and ignored.

    Returns:
        Whatever ``classifier`` returns for the decoded tensor.
    """
    # Forward the caller's settings; the encoder and classifier calls used to
    # pass hard-coded literals, silently overriding these arguments.
    pool5, mask1, mask2, mask3, mask4, mask5 = segnet_conv(
        inputs,
        kernel_size=kernel_size,
        kernel_initializer=kernel_initializer,
        batch_norm=batch_norm,
    )
    # NOTE(review): this passes five masks positionally; confirm that
    # segnet_deconv's signature accepts them.
    dec = segnet_deconv(
        pool5,
        mask1,
        mask2,
        mask3,
        mask4,
        mask5,
        kernel_size=kernel_size,
        kernel_initializer=kernel_initializer,
        batch_norm=batch_norm,
    )
    output = classifier(
        dec,
        ch_out=ch_out,
        kernel_size=kernel_size,
        final_activation=final_activation,
        batch_norm=batch_norm,
    )
    return output
inputs = Input(shape=(*params['image_size'], params['num_channels']), name='input')
# segnet() declares ch_out / kernel_size / final_activation.  The previous
# keywords (n_labels, kernel, pool_size, output_mode) matched none of its
# parameters and were silently swallowed by **kwargs.
outputs = segnet(inputs, ch_out=2, kernel_size=3, final_activation=None)
# we define our SegNet to output logits
model = Model(inputs, outputs)
你能帮我解决这个问题吗?